Example #1
def main():
    model_file = 'model/ssd300_vgg16_short.pb'
    graph = load_graph(model_file)

    with tf.Session(graph=graph) as sess:
        image_input = sess.graph.get_tensor_by_name(
            'import/define_input/image_input:0')
        result = sess.graph.get_tensor_by_name("import/result/result:0")

        image_path = 'demo/test.jpg'
        img = cv2.imread(image_path)
        img = np.float32(img)
        img = cv2.resize(img, (300, 300))
        img = np.expand_dims(img, axis=0)
        print('image_input', image_input)
        print('img', type(img), img.shape, img[0][1][1])
        enc_boxes = sess.run(result, feed_dict={image_input: img})
        print('enc_boxes', type(enc_boxes), len(enc_boxes), type(enc_boxes[0]),
              enc_boxes[0].shape)
        print('detect_result_[0][0]', enc_boxes[0][0])

        lid2name = {
            0: 'Aeroplane',
            1: 'Bicycle',
            2: 'Bird',
            3: 'Boat',
            4: 'Bottle',
            5: 'Bus',
            6: 'Car',
            7: 'Cat',
            8: 'Chair',
            9: 'Cow',
            10: 'Diningtable',
            11: 'Dog',
            12: 'Horse',
            13: 'Motorbike',
            14: 'Person',
            15: 'Pottedplant',
            16: 'Sheep',
            17: 'Sofa',
            18: 'Train',
            19: 'Tvmonitor'
        }
        preset = get_preset_by_name('vgg300')
        anchors = get_anchors_for_preset(preset)
        print('anchors', type(anchors))
        boxes = decode_boxes(enc_boxes[0], anchors, 0.5, lid2name, None)
        boxes = suppress_overlaps(boxes)[:200]
        print('boxes', boxes)

        img = cv2.imread(image_path)
        for box in boxes:
            color = (31, 119, 180)
            draw_box(img, box[1], color)

            box_data = '{} {} {} {} {} {}\n'.format(
                box[1].label, box[1].labelid, box[1].center.x, box[1].center.y,
                box[1].size.w, box[1].size.h)
            print('box_data', box_data)
        cv2.imwrite(image_path + '_out.jpg', img)
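Example #1 calls a load_graph helper that is not shown in the listing. A minimal sketch, assuming a standard TF1 frozen graph (the 'import/' prefix in the tensor names above comes from tf.import_graph_def's default name scope):

import tensorflow as tf

def load_graph(model_file):
    # Read the serialized GraphDef and import it into a fresh graph;
    # tf.import_graph_def prefixes node names with 'import/' by default.
    graph = tf.Graph()
    graph_def = tf.GraphDef()
    with open(model_file, 'rb') as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph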
Example #2
def initialize(self):
    self.anchors = get_anchors_for_preset(self.preset)
    self.vheight = len(self.anchors)
    self.vwidth = self.num_classes + 5  # background class + location offsets
    self.img_size = Size(1000, 1000)
    self.anchors_arr = anchors2array(self.anchors, self.img_size)
    self.initialized = True
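The Size type used above, and the box[1].center.x / box[1].size.w accesses in the other examples, suggest simple named tuples. A plausible sketch, not the repository's actual definitions:

from collections import namedtuple

Size = namedtuple('Size', ['w', 'h'])    # width/height, e.g. Size(1000, 1000)
Point = namedtuple('Point', ['x', 'y'])  # box centers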
def run(files, img_input_tensor, result_tensor, data, sess, batch_size=32):
    output_imgs = []

    preset = data['preset']
    colors = data['colors']
    lid2name = data['lid2name']
    anchors = get_anchors_for_preset(preset)

    for i in range(0, len(files), batch_size):
        batch_names = files[i:i+batch_size]
        batch_imgs = []
        batch = []
        for f in batch_names:
            img = cv2.imread(f)
            batch_imgs.append(img)
            img = cv2.resize(img, (300, 300))
            batch.append(img)

        batch = np.array(batch)
        feed = {img_input_tensor: batch}
        enc_boxes = sess.run(result_tensor, feed_dict=feed)

        for j in range(len(batch_names)):
            boxes = decode_boxes(enc_boxes[j], anchors, 0.5, lid2name, None)
            boxes = suppress_overlaps(boxes)[:200]
            name = os.path.basename(batch_names[j])

            for box in boxes:
                draw_box(batch_imgs[j], box[1], colors[box[1].label])
            output_imgs.append(batch_imgs[j])
            # (Optional) write per-box text data and the annotated image to
            # disk here, keyed by `name`, as done in Example #5 below.
    return output_imgs
def main():
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name',
                        default='test-combined-run1-6Dec',
                        help='project name')
    parser.add_argument('--data-dir', default='VOC', help='data directory')
    parser.add_argument('--vgg-dir',
                        default='vgg_graph',
                        help='directory for the VGG-16 model')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--batch_size_val',
                        type=int,
                        default=1,
                        help='batch size val')
    parser.add_argument('--tensorboard-dir',
                        default="tb-combined-run1-6Dec",
                        help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval',
                        type=int,
                        default=10,
                        help='checkpoint interval')
    parser.add_argument('--lr-values',
                        type=str,
                        default='0.00001;0.000001;0.00001',
                        help='learning rate values')
    parser.add_argument('--lr-boundaries',
                        type=str,
                        default='18300;36600',
                        help='learning rate change boundaries (in batches)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum for the optimizer')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=1e-6,
                        help='L2 normalization factor')
    parser.add_argument('--continue-training',
                        type=str2bool,
                        default='False',
                        help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers',
                        type=int,
                        default=6,
                        help='number of parallel generators')

    args = parser.parse_args()

    print('[i] Project name:         ', args.name)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] VGG directory:        ', args.vgg_dir)
    print('[i] # epochs:             ', args.epochs)
    print('[i] Batch size:           ', args.batch_size)
    print('[i] Batch size val:       ', args.batch_size_val)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval:  ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum:             ', args.momentum)
    print('[i] Weight decay:         ', args.weight_decay)
    print('[i] Continue:             ', args.continue_training)
    print('[i] Number of workers:    ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            if not os.path.exists(args.name):
                print('[i] Creating directory {}...'.format(args.name))
                os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch:', start_epoch + 1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples:   ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes:            ', td.num_classes)
        print('[i] Image size:           ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    with tf.Session(config=config) as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train / args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid / args.batch_size_val))

        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                print('[!] Learning rate values must be floats')
                sys.exit(1)

            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                sys.exit(1)

            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)

        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        scores_list = []
        class_scores_list = []
        precision_list = []
        recall_list = []
        f1_list = []
        iou_list = []

        scores_list_training = []
        class_scores_list_training = []
        precision_list_training = []
        recall_list_training = []
        f1_list_training = []
        iou_list_training = []

        avg_loss_per_epoch = []
        avg_scores_per_epoch = []
        avg_iou_per_epoch = []
        avg_iou_per_epoch_training = []

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)
        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epoch
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []
            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e + 1, args.epochs)

            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator,
                    total=n_train_batches,
                    desc=description,
                    unit='batches'):

                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                output_seg, result, loss_batch, _ = sess.run(
                    [net.logits_seg, net.result, net.losses, net.optimizer],
                    feed_dict=feed)

                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                if math.isnan(loss_batch['total']):
                    print('[!] total loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size_val,
                                           args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e + 1, args.epochs)

            if not os.path.isdir("%s/%04d" % ("Combined-fcn8-run1", e)):
                os.makedirs("%s/%04d" % ("Combined-fcn8-run1", e))

            counter_val = 0

            target = open("%s/%04d/val_scores.csv" % ("Combined-fcn8-run1", e),
                          'w')
            target.write("val_name, avg_accuracy, precision, recall, "
                         "f1 score, mean iou, per-class accuracies\n")

            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator,
                    total=n_valid_batches,
                    desc=description,
                    unit='batches'):

                gt_rev_onehot = reverse_one_hot(
                    one_hot_encode(np.squeeze(imgseg_gt_to_compare)))

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                result, output_seg, loss_batch = sess.run(
                    [net.result, net.logits_seg, net.losses], feed_dict=feed)

                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                accuracy, class_accuracies, prec, rec, f1, iou = evaluate_segmentation(
                    pred=output_seg_rev, label=gt_rev_onehot, num_classes=21)

                filename = str(counter_val)
                target.write("%s, %f, %f, %f, %f, %f" %
                             (filename, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)

                original_gt = colour_code_segmentation(gt_rev_onehot)

                cv2.imwrite(
                    "%s/%04d/%s_gt.png" % ("Combined-fcn8-run1", e, filename),
                    original_gt)
                cv2.imwrite(
                    "%s/%04d/%s_pred.png" %
                    ("Combined-fcn8-run1", e, filename), output_seg_output)

                counter_val += 1

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))

            target.close()
            #-------------------------------------------------------------------
            #Write summaries
            #-------------------------------------------------------------------
            avg_score = np.mean(scores_list)
            avg_scores_per_epoch.append(avg_score)
            avg_precision = np.mean(precision_list)
            avg_recall = np.mean(recall_list)
            avg_f1 = np.mean(f1_list)
            avg_iou = np.mean(iou_list)
            avg_iou_per_epoch.append(avg_iou)

            print("\nAverage validation accuracy for epoch # %04d = %f" %
                  (e, avg_score))
            print("Validation precision = ", avg_precision)
            print("Validation recall = ", avg_recall)
            print("Validation F1 score = ", avg_f1)
            print("Validation IoU score = ", avg_iou)

            training_loss.push(e + 1)
            validation_loss.push(e + 1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e + 1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e + 1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e + 1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e + 1, training_imgs_samples)
            validation_imgs.push(e + 1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checkpoint
            #-------------------------------------------------------------------
            if (e + 1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e + 1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
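compute_lr is not defined in these listings. A minimal sketch using TF1's piecewise-constant schedule, which is an assumption about what the original helper does:

def compute_lr(lr_values, lr_boundaries):
    # Piecewise-constant learning rate keyed on the global step;
    # len(lr_values) must equal len(lr_boundaries) + 1.
    global_step = tf.Variable(0, trainable=False, name='global_step')
    learning_rate = tf.train.piecewise_constant(global_step, lr_boundaries,
                                                lr_values)
    return learning_rate, global_step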
Example #5
def main():
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument("files", nargs="*")
    parser.add_argument('--model', default='model300.pb', help='model file')
    parser.add_argument('--training-data',
                        default='training-data-300.pkl',
                        help='training data')
    parser.add_argument('--output-dir',
                        default='test-out',
                        help='output directory')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    args = parser.parse_args()

    #---------------------------------------------------------------------------
    # Print parameters
    #---------------------------------------------------------------------------
    print('[i] Model:         ', args.model)
    print('[i] Training data: ', args.training_data)
    print('[i] Output dir:    ', args.output_dir)
    print('[i] Batch size:    ', args.batch_size)

    #---------------------------------------------------------------------------
    # Load the graph and the training data
    #---------------------------------------------------------------------------
    graph_def = tf.GraphDef()
    with open(args.model, 'rb') as f:
        serialized = f.read()
        graph_def.ParseFromString(serialized)

    with open(args.training_data, 'rb') as f:
        data = pickle.load(f)
        preset = data['preset']
        colors = data['colors']
        lid2name = data['lid2name']
        anchors = get_anchors_for_preset(preset)

    #---------------------------------------------------------------------------
    # Create the output directory
    #---------------------------------------------------------------------------
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    #---------------------------------------------------------------------------
    # Run the detections in batches
    #---------------------------------------------------------------------------
    with tf.Session() as sess:
        tf.import_graph_def(graph_def, name='detector')
        img_input = sess.graph.get_tensor_by_name('detector/image_input:0')
        result = sess.graph.get_tensor_by_name('detector/result/result:0')

        files = args.files  # the positional file arguments already parsed by argparse

        for i in tqdm(range(0, len(files), args.batch_size)):
            batch_names = files[i:i + args.batch_size]
            batch_imgs = []
            batch = []
            for f in batch_names:
                img = cv2.imread(f)
                batch_imgs.append(img)
                img = cv2.resize(img, (300, 300))
                batch.append(img)

            batch = np.array(batch)
            feed = {img_input: batch}
            enc_boxes = sess.run(result, feed_dict=feed)

            for j in range(len(batch_names)):
                boxes = decode_boxes(enc_boxes[j], anchors, 0.5, lid2name,
                                     None)
                boxes = suppress_overlaps(boxes)[:200]
                name = os.path.basename(batch_names[j])

                with open(os.path.join(args.output_dir, name + '.txt'),
                          'w') as box_file:
                    for box in boxes:
                        draw_box(batch_imgs[j], box[1], colors[box[1].label])

                        box_data = '{} {} {} {} {} {}\n'.format(
                            box[1].label, box[1].labelid, box[1].center.x,
                            box[1].center.y, box[1].size.w, box[1].size.h)
                        box_file.write(box_data)

                cv2.imwrite(os.path.join(args.output_dir, name), batch_imgs[j])
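The frozen model300.pb loaded above has to be produced from a trained checkpoint first. One way to do that in TF1 (the checkpoint path and the 'result/result' output node name are assumptions based on these examples):

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model/final.ckpt.meta')
    saver.restore(sess, 'model/final.ckpt')
    # Bake the variable values into constants and keep only the
    # subgraph needed to compute the detection output.
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['result/result'])
    with open('model300.pb', 'wb') as f:
        f.write(frozen.SerializeToString())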
Example #6
def main():
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name', default='test',
                        help='project name')
    parser.add_argument('--data-dir', default='pascal-voc',
                        help='data directory')
    parser.add_argument('--vgg-dir', default='vgg_graph',
                        help='directory for the VGG-16 model')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size')
    parser.add_argument('--tensorboard-dir', default="tb",
                        help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval', type=int, default=5,
                        help='checkpoint interval')
    parser.add_argument('--lr-values', type=str, default='0.00075;0.0001;0.00001',
                        help='learning rate values')
    parser.add_argument('--lr-boundaries', type=str, default='320000;400000',
                        help='learning rate change boundaries (in batches)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for the optimizer')
    parser.add_argument('--weight-decay', type=float, default=0.0005,
                        help='L2 normalization factor')
    parser.add_argument('--continue-training', type=str2bool, default='False',
                        help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers', type=int, default=mp.cpu_count(),
                        help='number of parallel generators')

    args = parser.parse_args()

    print('[i] Project name:         ', args.name)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] VGG directory:        ', args.vgg_dir)
    print('[i] # epochs:             ', args.epochs)
    print('[i] Batch size:           ', args.batch_size)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval:  ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum:             ', args.momentum)
    print('[i] Weight decay:         ', args.weight_decay)
    print('[i] Continue:             ', args.continue_training)
    print('[i] Number of workers:    ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            if not os.path.exists(args.name):
                print('[i] Creating directory {}...'.format(args.name))
                os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch:    ', start_epoch+1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples:   ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes:            ', td.num_classes)
        print('[i] Image size:           ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
    with tf.Session() as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train/args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid/args.batch_size))

        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                print('[!] Learning rate values must be floats')
                sys.exit(1)

            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                sys.exit(1)

            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)

        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)

        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epoch
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []

            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e+1, args.epochs)
            for x, y, gt_boxes in tqdm(generator, total=n_train_batches,
                                       desc=description, unit='batches'):

                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {net.image_input: x,
                        net.labels: y}
                result, loss_batch, _ = sess.run([net.result, net.losses,
                                                  net.optimizer],
                                                 feed_dict=feed)

                if math.isnan(loss_batch['confidence']):
                    print('[!] Confidence loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size, args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e+1, args.epochs)

            for x, y, gt_boxes in tqdm(generator, total=n_valid_batches,
                                       desc=description, unit='batches'):
                feed = {net.image_input: x,
                        net.labels: y}
                result, loss_batch = sess.run([net.result, net.losses],
                                              feed_dict=feed)

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))

            #-------------------------------------------------------------------
            # Write summaries
            #-------------------------------------------------------------------
            training_loss.push(e+1)
            validation_loss.push(e+1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e+1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e+1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e+1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e+1, training_imgs_samples)
            validation_imgs.push(e+1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checkpoint
            #-------------------------------------------------------------------
            if (e+1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e+1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
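The str2bool converter passed to argparse above is not shown; a common minimal version (an assumption) is:

def str2bool(v):
    # Accept the usual truthy spellings on the command line.
    return str(v).lower() in ('yes', 'true', 't', '1')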
def main():
    checkpoint_file = 'model/e25.ckpt'
    metagraph_file = checkpoint_file + '.meta'
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        preset = get_preset_by_name('vgg300')
        anchors = get_anchors_for_preset(preset)
        net = SSDVGG(sess, preset)
        net.build_from_metagraph(metagraph_file, checkpoint_file)

        #for tensor in tf.get_default_graph().as_graph_def().node: print(tensor.name)

        image_path = 'demo/test.jpg'
        img = cv2.imread(image_path)
        img = np.float32(img)
        img = cv2.resize(img, (300, 300))
        img = np.expand_dims(img, axis=0)
        print('image_input', net.image_input)
        print('img', type(img), img.shape, img[0][1][1])
        enc_boxes = sess.run(net.result, feed_dict={net.image_input: img})
        print('enc_boxes', type(enc_boxes), len(enc_boxes), type(enc_boxes[0]),
              enc_boxes[0].shape)

        lid2name = {
            0: 'Aeroplane',
            1: 'Bicycle',
            2: 'Bird',
            3: 'Boat',
            4: 'Bottle',
            5: 'Bus',
            6: 'Car',
            7: 'Cat',
            8: 'Chair',
            9: 'Cow',
            10: 'Diningtable',
            11: 'Dog',
            12: 'Horse',
            13: 'Motorbike',
            14: 'Person',
            15: 'Pottedplant',
            16: 'Sheep',
            17: 'Sofa',
            18: 'Train',
            19: 'Tvmonitor'
        }
        print('anchors', type(anchors))
        boxes = decode_boxes(enc_boxes[0], anchors, 0.5, lid2name, None)
        boxes = suppress_overlaps(boxes)[:200]

        img = cv2.imread(image_path)
        for box in boxes:
            color = (31, 119, 180)
            draw_box(img, box[1], color)

            box_data = '{} {} {} {} {} {}\n'.format(
                box[1].label, box[1].labelid, box[1].center.x, box[1].center.y,
                box[1].size.w, box[1].size.h)
            print('box_data', box_data)
        cv2.imwrite(image_path + '_out.jpg', img)
def main():
    # Parse commandline
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument("files", nargs="*")

    parser.add_argument('--checkpoint-dir',
                        default='pascal-voc/checkpoints',
                        help='project name')
    parser.add_argument('--checkpoint',
                        type=int,
                        default=-1,
                        help='checkpoint to restore; -1 is the most recent')

    parser.add_argument('--data-source',
                        default="pascal-voc",
                        help='Use test files from the data source')
    parser.add_argument('--data-dir',
                        default='pascal-voc',
                        help='data directory')
    parser.add_argument('--training-data',
                        default='pascal-voc/training-data.pkl',
                        help='Information about parameters used for training')
    parser.add_argument('--output-dir',
                        default='pascal-voc/annotated/train',
                        help='directory for the resulting images')
    parser.add_argument('--annotate',
                        type=str2bool,
                        default='False',
                        help="Annotate the data samples")
    parser.add_argument('--dump-predictions',
                        type=str2bool,
                        default='False',
                        help="Dump raw predictions")
    parser.add_argument('--summary',
                        type=str2bool,
                        default='True',
                        help='dump the detections in Pascal VOC format')
    parser.add_argument('--compute-stats',
                        type=str2bool,
                        default='True',
                        help="Compute the mAP stats")
    parser.add_argument('--sample',
                        default='train',
                        choices=['train', 'valid'])

    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--threshold',
                        type=float,
                        default=0.5,
                        help='confidence threshold')

    args = parser.parse_args()

    # Print parameters
    print('[i] Checkpoint directory: ', args.checkpoint_dir)

    print('[i] Data source:          ', args.data_source)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] Training data:        ', args.training_data)
    print('[i] Output directory:     ', args.output_dir)
    print('[i] Annotate:             ', args.annotate)
    print('[i] Dump predictions:     ', args.dump_predictions)
    print('[i] Summary:              ', args.summary)
    print('[i] Compute stats:        ', args.compute_stats)
    print('[i] Sample:               ', args.sample)

    print('[i] Batch size:           ', args.batch_size)
    print('[i] Threshold:            ', args.threshold)

    # Check if we can get the checkpoint
    state = tf.train.get_checkpoint_state(args.checkpoint_dir)
    if state is None:
        print('[!] No network state found in ' + args.checkpoint_dir)
        return 1

    try:
        checkpoint_file = state.all_model_checkpoint_paths[args.checkpoint]
    except IndexError:
        print('[!] Cannot find checkpoint ' + str(args.checkpoint))
        return 1

    metagraph_file = checkpoint_file + '.meta'

    if not os.path.exists(metagraph_file):
        print('[!] Cannot find metagraph ' + metagraph_file)
        return 1

    # Load the training data parameters
    try:
        with open(args.training_data, 'rb') as f:
            data = pickle.load(f)
        preset = data['preset']
        colors = data['colors']
        lid2name = data['lid2name']
        image_size = preset.image_size
        anchors = get_anchors_for_preset(preset)
    except (FileNotFoundError, IOError, KeyError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    # Load the samples according to data source and sample type
    try:
        if args.sample == 'train':
            with open(args.data_dir + '/train-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        else:
            with open(args.data_dir + '/valid-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        num_samples = len(samples)
        print('[i] # samples:         ', num_samples)
    except (ImportError, AttributeError, RuntimeError) as e:
        print('[!] Unable to load data source:', str(e))
        return 1

    # Create a list of files to analyse and make sure that the output directory exists
    files = []

    for sample in samples:
        files.append(sample.filename)

    files = list(filter(lambda x: os.path.exists(x), files))
    if files:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

    # Print model and dataset stats
    print('[i] Network checkpoint:', checkpoint_file)
    print('[i] Metagraph file:    ', metagraph_file)
    print('[i] Image size:        ', image_size)
    print('[i] Number of files:   ', len(files))

    # Create the network
    if args.compute_stats:
        ap_calc = APCalculator()

    if args.summary:
        summary = PascalSummary()

    with tf.Session() as sess:
        print('[i] Creating the model...')
        net = SSDVGG(sess, preset)
        net.build_from_metagraph(metagraph_file, checkpoint_file)

        # Process the images
        generator = sample_generator(files, image_size, args.batch_size)
        n_sample_batches = int(math.ceil(len(files) / args.batch_size))
        description = '[i] Processing samples'

        for x, idxs in tqdm(generator,
                            total=n_sample_batches,
                            desc=description,
                            unit='batches'):
            feed = {net.image_input: x, net.keep_prob: 1}
            enc_boxes = sess.run(net.result, feed_dict=feed)

            # Process the predictions
            for i in range(enc_boxes.shape[0]):
                boxes = decode_boxes(enc_boxes[i], anchors, args.threshold,
                                     lid2name, None)
                boxes = suppress_overlaps(boxes)[:200]
                filename = files[idxs[i]]
                basename = os.path.basename(filename)

                # Annotate samples
                if args.annotate:
                    img = cv2.imread(filename)
                    for box in boxes:
                        draw_box(img, box[1], colors[box[1].label])
                    fn = args.output_dir + '/images/' + basename
                    cv2.imwrite(fn, img)

                # Dump the predictions
                if args.dump_predictions:
                    raw_fn = args.output_dir + '/' + basename + '.npy'
                    np.save(raw_fn, enc_boxes[i])

                # Add predictions to the stats calculator and to the summary
                if args.compute_stats:
                    ap_calc.add_detections(samples[idxs[i]].boxes, boxes)

                if args.summary:
                    summary.add_detections(filename, boxes)

    # Compute and print the stats
    if args.compute_stats:
        aps = ap_calc.compute_aps()
        for k, v in aps.items():
            print('[i] AP [{0}]: {1:.3f}'.format(k, v))
        print('[i] mAP: {0:.3f}'.format(APs2mAP(aps)))

    # Write the summary files
    if args.summary:
        summary.write_summary(args.output_dir + "/summaries")

    print('[i] All done.')
    return 0
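sample_generator is referenced but never listed. A plausible sketch; the Size-style image_size and the (batch, indices) output shape are inferred from the call site above:

def sample_generator(files, image_size, batch_size):
    # Yield (batch_of_images, file_indices) pairs so the caller can map
    # predictions back to file names.
    for offset in range(0, len(files), batch_size):
        names = files[offset:offset + batch_size]
        imgs = [cv2.resize(cv2.imread(f), (image_size.w, image_size.h))
                for f in names]
        yield np.array(imgs, dtype=np.float32), \
            list(range(offset, offset + len(names)))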
Example #9
def main():
    # Parse the commandline
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument('--model', default='./pascal-voc/models/e225-SSD300-VGG16-PASCALVOC.tflite', help='model file')
    parser.add_argument('--training-data', default='./pascal-voc/training-data.pkl', help='training data')
    parser.add_argument("--input-dir", default='./test/in', help='input directory')
    parser.add_argument('--output-dir', default='./test/out', help='output directory')
    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
    args = parser.parse_args()

    # Print parameters
    print('[i] Model:         ', args.model)
    print('[i] Training data: ', args.training_data)
    print('[i] Input dir:     ', args.input_dir)
    print('[i] Output dir:    ', args.output_dir)
    print('[i] Batch size:    ', args.batch_size)

    # Load the training data
    with open(args.training_data, 'rb') as f:
        data = pickle.load(f)
        preset = data['preset']
        colors = data['colors']
        lid2name = data['lid2name']
        anchors = get_anchors_for_preset(preset)

    # Get the input images
    images = os.listdir(args.input_dir)
    images = ["%s/%s" % (args.input_dir, image) for image in images]

    # Create the output directory
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=args.model)
    interpreter.allocate_tensors()

    # Run the detections in batches
    for i in tqdm(range(0, len(images), args.batch_size)):
        batch_names = images[i:i+args.batch_size]
        batch_imgs = []
        batch = []
        for f in batch_names:
            img = cv2.imread(f)
            batch_imgs.append(img)
            img = cv2.resize(img, (300, 300))
            batch.append(img.astype(np.float32))

        batch = np.array(batch)

        # Get input and output tensors.
        input_details = interpreter.get_input_details()
        interpreter.set_tensor(input_details[0]['index'], batch)
        interpreter.invoke()

        output_details = interpreter.get_output_details()
        enc_boxes = interpreter.get_tensor(output_details[0]['index'])

        for k in range(len(batch_names)):
            boxes = decode_boxes(enc_boxes[k], anchors, 0.5, lid2name, None)
            boxes = suppress_overlaps(boxes)[:200]
            name = os.path.basename(batch_names[k])
            meta = {}
            for j, box in enumerate(boxes):
                draw_box(batch_imgs[k], box[1], colors[box[1].label])
                box_data = {}
                box_data['Label'] = box[1].label  # no trailing comma: it would store a tuple
                box_data['LabelID'] = str(box[1].labelid)
                box_data['Center'] = [box[1].center.x, box[1].center.y]
                box_data['Size'] = [box[1].size.w, box[1].size.h]
                box_data['Confidence'] = str(box[0])
                meta["prediction_%s" % (j + 1)] = box_data
            with open(os.path.join(args.output_dir, name + '.json'),
                      'w') as json_file:
                json.dump(meta, json_file, indent=4)

            cv2.imwrite(os.path.join(args.output_dir, name), batch_imgs[k])
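The .tflite model used above can be produced from a frozen graph with TF1's converter. A sketch; the file paths and the tensor names 'image_input' and 'result/result' are assumptions matching the other examples:

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    './pascal-voc/frozen/e225-SSD300-VGG16-PASCALVOC.pb',
    input_arrays=['image_input'],
    output_arrays=['result/result'],
    input_shapes={'image_input': [1, 300, 300, 3]})
tflite_model = converter.convert()
with open('./pascal-voc/models/e225-SSD300-VGG16-PASCALVOC.tflite', 'wb') as f:
    f.write(tflite_model)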
def main():
    # Parse the commandline
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument(
        '--model',
        default='./pascal-voc/frozen/e225-SSD300-VGG16-PASCALVOC.pb',
        help='model file')
    parser.add_argument('--training-data',
                        default='./pascal-voc/training-data.pkl',
                        help='training data')
    parser.add_argument("--input-dir",
                        default='./test/in',
                        help='input directory')
    parser.add_argument('--output-dir',
                        default='./test/out',
                        help='output directory')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='batch size')
    args = parser.parse_args()

    # Print parameters
    print('[i] Model:         ', args.model)
    print('[i] Training data: ', args.training_data)
    print('[i] Input dir:     ', args.input_dir)
    print('[i] Output dir:    ', args.output_dir)
    print('[i] Batch size:    ', args.batch_size)

    # Load the graph and the training data
    graph_def = tf.GraphDef()
    with open(args.model, 'rb') as f:
        serialized = f.read()
        graph_def.ParseFromString(serialized)

    with open(args.training_data, 'rb') as f:
        data = pickle.load(f)
        preset = data['preset']
        colors = data['colors']
        lid2name = data['lid2name']
        anchors = get_anchors_for_preset(preset)

    # Get the input images
    images = os.listdir(args.input_dir)
    images = ["%s/%s" % (args.input_dir, image) for image in images]

    # Create the output directory
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Run the detections in batches
    with tf.Session() as sess:
        tf.import_graph_def(graph_def, name='detector')
        img_input = sess.graph.get_tensor_by_name('detector/image_input:0')
        result = sess.graph.get_tensor_by_name('detector/result/result:0')

        for i in tqdm(range(0, len(images), args.batch_size)):
            batch_names = images[i:i + args.batch_size]
            batch_imgs = []
            batch = []
            for f in batch_names:
                img = cv2.imread(f)
                batch_imgs.append(img)
                img = cv2.resize(img, (300, 300))
                batch.append(img)

            batch = np.array(batch)
            feed = {img_input: batch}
            enc_boxes = sess.run(result, feed_dict=feed)

            for k in range(len(batch_names)):
                boxes = decode_boxes(enc_boxes[k], anchors, 0.5, lid2name,
                                     None)
                boxes = suppress_overlaps(boxes)[:200]
                name = os.path.basename(batch_names[k])
                meta = {}
                for j, box in enumerate(boxes):
                    draw_box(batch_imgs[k], box[1], colors[box[1].label])
                    box_data = {}
                    box_data['Label'] = box[1].label  # no trailing comma: it would store a tuple
                    box_data['LabelID'] = str(box[1].labelid)
                    box_data['Center'] = [box[1].center.x, box[1].center.y]
                    box_data['Size'] = [box[1].size.w, box[1].size.h]
                    box_data['Confidence'] = str(box[0])
                    meta["prediction_%s" % (j + 1)] = box_data
                with open(os.path.join(args.output_dir, name + '.json'),
                          'w') as json_file:
                    json.dump(meta, json_file, indent=4)

                cv2.imwrite(os.path.join(args.output_dir, name), batch_imgs[k])
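suppress_overlaps is used by every example but never listed. It behaves like greedy per-class non-maximum suppression over (confidence, box) pairs; a sketch under that assumption, where jaccard_overlap is a hypothetical IoU helper and the 0.45 threshold is a guess:

def suppress_overlaps(boxes, threshold=0.45):
    # Greedy NMS: keep the highest-confidence box, then drop any later box
    # of the same class that overlaps a kept box too strongly.
    boxes = sorted(boxes, key=lambda b: b[0], reverse=True)
    kept = []
    for conf, box in boxes:
        if all(k.label != box.label or jaccard_overlap(k, box) < threshold
               for _, k in kept):
            kept.append((conf, box))
    return kept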