Code example #1
import argparse
import math
import multiprocessing as mp
import os
import sys

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Project-specific helpers (str2bool, compute_lr, SSDVGG, TrainingData,
# APCalculator, APs2mAP, the summary classes, get_anchors_for_preset,
# decode_boxes, suppress_overlaps, initialize_uninitialized_variables) are
# assumed to be importable from the accompanying project modules.


def main():
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name', default='test',
                        help='project name')
    parser.add_argument('--data-dir', default='pascal-voc',
                        help='data directory')
    parser.add_argument('--vgg-dir', default='vgg_graph',
                        help='directory for the VGG-16 model')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size')
    parser.add_argument('--tensorboard-dir', default="tb",
                        help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval', type=int, default=5,
                        help='checkpoint interval')
    parser.add_argument('--lr-values', type=str, default='0.00075;0.0001;0.00001',
                        help='learning rate values')
    parser.add_argument('--lr-boundaries', type=str, default='320000;400000',
                        help='learning rate change boundaries (in batches)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum for the optimizer')
    parser.add_argument('--weight-decay', type=float, default=0.0005,
                        help='L2 regularization factor')
    parser.add_argument('--continue-training', type=str2bool, default='False',
                        help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers', type=int, default=mp.cpu_count(),
                        help='number of parallel generators')

    args = parser.parse_args()

    print('[i] Project name:         ', args.name)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] VGG directory:        ', args.vgg_dir)
    print('[i] # epochs:             ', args.epochs)
    print('[i] Batch size:           ', args.batch_size)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval:  ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum:             ', args.momentum)
    print('[i] Weight decay:         ', args.weight_decay)
    print('[i] Continue:             ', args.continue_training)
    print('[i] Number of workers:    ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        last_epoch = None
        checkpoint_file = None
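        # Checkpoints are named 'e<epoch>.ckpt' (see the save below), so
        # strip the leading 'e' and the extension to recover the epoch.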
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            print('[i] Creating directory {}...'.format(args.name))
            os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch:    ', start_epoch+1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples:   ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes:            ', td.num_classes)
        print('[i] Image size:           ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
    with tf.Session() as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train/args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid/args.batch_size))

        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                print('[!] Learning rate values must be floats')
                sys.exit(1)

            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                sys.exit(1)

            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)

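        # Only variables that were not restored above (i.e. are still
        # uninitialized) get initialized here.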
        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)

        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epochs
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []

            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e+1, args.epochs)
            for x, y, gt_boxes in tqdm(generator, total=n_train_batches,
                                       desc=description, unit='batches'):

                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {net.image_input: x,
                        net.labels: y}
                result, loss_batch, _ = sess.run([net.result, net.losses,
                                                  net.optimizer],
                                                 feed_dict=feed)

                if math.isnan(loss_batch['confidence']):
                    print('[!] Confidence loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

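                # Detection decoding and sample collection are skipped in
                # the first epoch.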
                if e == 0: continue

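                # Decode the SSD output into boxes (0.5 confidence
                # threshold) and remove duplicates with non-maximum
                # suppression.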
                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size, args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e+1, args.epochs)

            for x, y, gt_boxes in tqdm(generator, total=n_valid_batches,
                                       desc=description, unit='batches'):
                feed = {net.image_input: x,
                        net.labels: y}
                result, loss_batch = sess.run([net.result, net.losses],
                                              feed_dict=feed)

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))

            #-------------------------------------------------------------------
            # Write summaries
            #-------------------------------------------------------------------
            training_loss.push(e+1)
            validation_loss.push(e+1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e+1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e+1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e+1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e+1, training_imgs_samples)
            validation_imgs.push(e+1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checkpoint
            #-------------------------------------------------------------------
            if (e+1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e+1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
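This training script calls two small helpers that are not shown in the listing (example #2 reuses them as well). The sketch below is a minimal, assumed implementation: str2bool is the usual argparse workaround (with type=bool, any non-empty string would parse as True), and compute_lr is assumed to wrap tf.train.piecewise_constant and to create the global_step variable that it returns alongside the learning rate; the real project code may differ in details.

def str2bool(v):
    # Map common commandline spellings to a bool for argparse.
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected')


def compute_lr(lr_values, lr_boundaries):
    # Assumed implementation: a piecewise-constant schedule, e.g. values
    # [0.00075, 0.0001, 0.00001] with boundaries [320000, 400000] switch
    # the rate after 320k and 400k processed batches.
    with tf.variable_scope('learning_rate'):
        global_step = tf.Variable(0, trainable=False, name='global_step')
        learning_rate = tf.train.piecewise_constant(global_step,
                                                    lr_boundaries, lr_values)
    return learning_rate, global_step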
Code example #2
import argparse
import math
import os
import sys

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# On top of the helpers used by example #1, this variant assumes the
# segmentation utilities (one_hot_encode, reverse_one_hot,
# colour_code_segmentation, evaluate_segmentation) are importable from the
# accompanying project modules.


def main():
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name',
                        default='test-combined-run1-6Dec',
                        help='project name')
    parser.add_argument('--data-dir', default='VOC', help='data directory')
    parser.add_argument('--vgg-dir',
                        default='vgg_graph',
                        help='directory for the VGG-16 model')
    parser.add_argument('--epochs',
                        type=int,
                        default=300,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--batch_size_val',
                        type=int,
                        default=1,
                        help='batch size val')
    parser.add_argument('--tensorboard-dir',
                        default="tb-combined-run1-6Dec",
                        help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval',
                        type=int,
                        default=10,
                        help='checkpoint interval')
    parser.add_argument('--lr-values',
                        type=str,
                        default='0.00001;0.000001;0.00001',
                        help='learning rate values')
    parser.add_argument('--lr-boundaries',
                        type=str,
                        default='18300;36600',
                        help='learning rate change boundaries (in batches)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum for the optimizer')
    parser.add_argument('--weight-decay',
                        type=float,
                        default=1e-6,
                        help='L2 regularization factor')
    parser.add_argument('--continue-training',
                        type=str2bool,
                        default='False',
                        help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers',
                        type=int,
                        default=6,
                        help='number of parallel generators')

    args = parser.parse_args()

    print('[i] Project name:         ', args.name)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] VGG directory:        ', args.vgg_dir)
    print('[i] # epochs:             ', args.epochs)
    print('[i] Batch size:           ', args.batch_size)
    print('[i] Batch size val:       ', args.batch_size_val)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval:  ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum:             ', args.momentum)
    print('[i] Weight decay:         ', args.weight_decay)
    print('[i] Continue:             ', args.continue_training)
    print('[i] Number of workers:    ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'

        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            if not os.path.exists(args.name):
                print('[i] Creating directory {}...'.format(args.name))
                os.makedirs(args.name)
        except OSError as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch:', start_epoch + 1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples:   ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes:            ', td.num_classes)
        print('[i] Image size:           ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
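    # Cap the per-process GPU memory so the card can be shared.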
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    with tf.Session(config=config) as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train / args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid / args.batch_size_val))

        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                print('[!] Learning rate values must be floats')
                sys.exit(1)

            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                sys.exit(1)

            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)

        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir,
                                               sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        # Aggregates of the validation segmentation metrics, one entry per
        # epoch; the per-image lists are reset inside the epoch loop below.
        avg_scores_per_epoch = []
        avg_iou_per_epoch = []

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)
        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epochs
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []

            # Reset the per-image segmentation metrics; without this the
            # averages computed below would accumulate over all epochs.
            scores_list = []
            class_scores_list = []
            precision_list = []
            recall_list = []
            f1_list = []
            iou_list = []
            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e + 1, args.epochs)

            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator,
                    total=n_train_batches,
                    desc=description,
                    unit='batches'):

                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                output_seg, result, loss_batch, _ = sess.run(
                    [net.logits_seg, net.result, net.losses, net.optimizer],
                    feed_dict=feed)

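                # Colour-code the first segmentation map of the batch (not
                # used further in the training loop, but handy for
                # debugging).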
                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                if math.isnan(loss_batch['total']):
                    print('[!] total loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size_val,
                                           args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e + 1, args.epochs)

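            # Per-image segmentation scores and colour-coded outputs for
            # this epoch go into a hard-coded results directory.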
            if not os.path.isdir("%s/%04d" % ("Combined-fcn8-run1", e)):
                os.makedirs("%s/%04d" % ("Combined-fcn8-run1", e))

            counter_val = 0

            target = open("%s/%04d/val_scores.csv" % ("Combined-fcn8-run1", e),
                          'w')
            target.write(
                "val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n"
                % str(counter_val))

            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator,
                    total=n_valid_batches,
                    desc=description,
                    unit='batches'):

                gt_rev_onehot = reverse_one_hot(
                    one_hot_encode(np.squeeze(imgseg_gt_to_compare)))

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                result, output_seg, loss_batch = sess.run(
                    [net.result, net.logits_seg, net.losses], feed_dict=feed)

                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                accuracy, class_accuracies, prec, rec, f1, iou = evaluate_segmentation(
                    pred=output_seg_rev, label=gt_rev_onehot, num_classes=21)

                filename = str(counter_val)
                target.write("%s, %f, %f, %f, %f, %f" %
                             (filename, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)

                original_gt = colour_code_segmentation(gt_rev_onehot)

                cv2.imwrite(
                    "%s/%04d/%s_gt.png" % ("Combined-fcn8-run1", e, filename),
                    original_gt)
                cv2.imwrite(
                    "%s/%04d/%s_pred.png" %
                    ("Combined-fcn8-run1", e, filename), output_seg_output)

                counter_val += 1

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))

            target.close()
            #-------------------------------------------------------------------
            #Write summaries
            #-------------------------------------------------------------------
            avg_score = np.mean(scores_list)
            avg_scores_per_epoch.append(avg_score)
            avg_precision = np.mean(precision_list)
            avg_recall = np.mean(recall_list)
            avg_f1 = np.mean(f1_list)
            avg_iou = np.mean(iou_list)
            avg_iou_per_epoch.append(avg_iou)

            print("\nAverage validation accuracy for epoch # %04d = %f" %
                  (e, avg_score))
            print("Validation precision = ", avg_precision)
            print("Validation recall = ", avg_recall)
            print("Validation F1 score = ", avg_f1)
            print("Validation IoU score = ", avg_iou)

            training_loss.push(e + 1)
            validation_loss.push(e + 1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e + 1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e + 1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e + 1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e + 1, training_imgs_samples)
            validation_imgs.push(e + 1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checkpoint
            #-------------------------------------------------------------------
            if (e + 1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e + 1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
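Example #2 additionally relies on segmentation helpers that are not part of the listing. Below is a minimal sketch of the shape and colour conversions, assuming a global LABEL_COLORS table (one RGB triple per class id); evaluate_segmentation, which returns the accuracy, per-class accuracies, precision, recall, F1, and IoU metrics, is project-specific and omitted here.

def one_hot_encode(label):
    # Colour-coded (H, W, 3) ground truth -> (H, W, num_classes) one-hot.
    maps = [np.all(label == colour, axis=-1) for colour in LABEL_COLORS]
    return np.stack(maps, axis=-1).astype(np.float32)


def reverse_one_hot(image):
    # (H, W, num_classes) scores or one-hot map -> (H, W) class ids.
    return np.argmax(image, axis=-1)


def colour_code_segmentation(image):
    # (H, W) class ids -> (H, W, 3) colour image.
    return np.array(LABEL_COLORS)[image.astype(int)]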
Code example #3
import argparse
import math
import os
import pickle

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Project-specific helpers (str2bool, SSDVGG, APCalculator, APs2mAP,
# PascalSummary, get_anchors_for_preset, decode_boxes, suppress_overlaps,
# draw_box, sample_generator) are assumed to be importable from the
# accompanying project modules.


def main():
    # Parse commandline
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument("files", nargs="*")

    parser.add_argument('--checkpoint-dir',
                        default='pascal-voc/checkpoints',
                        help='checkpoint directory')
    parser.add_argument('--checkpoint',
                        type=int,
                        default=-1,
                        help='checkpoint to restore; -1 is the most recent')

    parser.add_argument('--data-source',
                        default="pascal-voc",
                        help='Use test files from the data source')
    parser.add_argument('--data-dir',
                        default='pascal-voc',
                        help='data directory')
    parser.add_argument('--training-data',
                        default='pascal-voc/training-data.pkl',
                        help='Information about parameters used for training')
    parser.add_argument('--output-dir',
                        default='pascal-voc/annotated/train',
                        help='directory for the resulting images')
    parser.add_argument('--annotate',
                        type=str2bool,
                        default='False',
                        help="Annotate the data samples")
    parser.add_argument('--dump-predictions',
                        type=str2bool,
                        default='False',
                        help="Dump raw predictions")
    parser.add_argument('--summary',
                        type=str2bool,
                        default='True',
                        help='dump the detections in Pascal VOC format')
    parser.add_argument('--compute-stats',
                        type=str2bool,
                        default='True',
                        help="Compute the mAP stats")
    parser.add_argument('--sample',
                        default='train',
                        choices=['train', 'valid'])

    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='batch size')
    parser.add_argument('--threshold',
                        type=float,
                        default=0.5,
                        help='confidence threshold')

    args = parser.parse_args()

    # Print parameters
    print('[i] Checkpoint directory: ', args.checkpoint_dir)

    print('[i] Data source:          ', args.data_source)
    print('[i] Data directory:       ', args.data_dir)
    print('[i] Training data:        ', args.training_data)
    print('[i] Output directory:     ', args.output_dir)
    print('[i] Annotate:             ', args.annotate)
    print('[i] Dump predictions:     ', args.dump_predictions)
    print('[i] Summary:              ', args.summary)
    print('[i] Compute stats:        ', args.compute_stats)
    print('[i] Sample:               ', args.sample)

    print('[i] Batch size:           ', args.batch_size)
    print('[i] Threshold:            ', args.threshold)

    # Check if we can get the checkpoint
    state = tf.train.get_checkpoint_state(args.checkpoint_dir)
    if state is None:
        print('[!] No network state found in ' + args.checkpoint_dir)
        return 1

    try:
        checkpoint_file = state.all_model_checkpoint_paths[args.checkpoint]
    except IndexError:
        print('[!] Cannot find checkpoint ' + str(args.checkpoint))
        return 1

    metagraph_file = checkpoint_file + '.meta'

    if not os.path.exists(metagraph_file):
        print('[!] Cannot find metagraph ' + metagraph_file)
        return 1

    # Load the training data parameters
    try:
        with open(args.training_data, 'rb') as f:
            data = pickle.load(f)
        preset = data['preset']
        colors = data['colors']
        lid2name = data['lid2name']
        image_size = preset.image_size
        anchors = get_anchors_for_preset(preset)
    except (FileNotFoundError, IOError, KeyError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    # Load the samples according to data source and sample type
    try:
        if args.sample == 'train':
            with open(args.data_dir + '/train-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        else:
            with open(args.data_dir + '/valid-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        num_samples = len(samples)
        print('[i] # samples:         ', num_samples)
    except (ImportError, AttributeError, RuntimeError) as e:
        print('[!] Unable to load data source:', str(e))
        return 1

    # Create a list of files to analyse and make sure that the output directory exists
    files = []

    for sample in samples:
        files.append(sample.filename)

    files = list(filter(lambda x: os.path.exists(x), files))
    if files:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        # The annotated images go into an 'images' subdirectory (see below),
        # which needs to exist before cv2.imwrite is called.
        if args.annotate and not os.path.exists(args.output_dir + '/images'):
            os.makedirs(args.output_dir + '/images')

    # Print model and dataset stats
    print('[i] Network checkpoint:', checkpoint_file)
    print('[i] Metagraph file:    ', metagraph_file)
    print('[i] Image size:        ', image_size)
    print('[i] Number of files:   ', len(files))

    # Create the network
    if args.compute_stats:
        ap_calc = APCalculator()

    if args.summary:
        summary = PascalSummary()

    with tf.Session() as sess:
        print('[i] Creating the model...')
        net = SSDVGG(sess, preset)
        net.build_from_metagraph(metagraph_file, checkpoint_file)

        # Process the images
        generator = sample_generator(files, image_size, args.batch_size)
        n_sample_batches = int(math.ceil(len(files) / args.batch_size))
        description = '[i] Processing samples'

        for x, idxs in tqdm(generator,
                            total=n_sample_batches,
                            desc=description,
                            unit='batches'):
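            # keep_prob = 1 disables dropout for inference.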
            feed = {net.image_input: x, net.keep_prob: 1}
            enc_boxes = sess.run(net.result, feed_dict=feed)

            # Process the predictions
            for i in range(enc_boxes.shape[0]):
                boxes = decode_boxes(enc_boxes[i], anchors, args.threshold,
                                     lid2name, None)
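                # Suppress overlapping boxes and keep at most 200
                # detections per image.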
                boxes = suppress_overlaps(boxes)[:200]
                filename = files[idxs[i]]
                basename = os.path.basename(filename)

                # Annotate samples
                if args.annotate:
                    img = cv2.imread(filename)
                    for box in boxes:
                        draw_box(img, box[1], colors[box[1].label])
                    fn = args.output_dir + '/images/' + basename
                    cv2.imwrite(fn, img)

                # Dump the predictions
                if args.dump_predictions:
                    raw_fn = args.output_dir + '/' + basename + '.npy'
                    np.save(raw_fn, enc_boxes[i])

                # Add predictions to the stats calculator and to the summary
                if args.compute_stats:
                    ap_calc.add_detections(samples[idxs[i]].boxes, boxes)

                if args.summary:
                    summary.add_detections(filename, boxes)

    # Compute and print the stats
    if args.compute_stats:
        aps = ap_calc.compute_aps()
        for k, v in aps.items():
            print('[i] AP [{0}]: {1:.3f}'.format(k, v))
        print('[i] mAP: {0:.3f}'.format(APs2mAP(aps)))

    # Write the summary files
    if args.summary:
        summary.write_summary(args.output_dir + "/summaries")

    print('[i] All done.')
    return 0
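The sample_generator helper used above is also not shown. Below is a minimal sketch under the assumption that image_size is a structure with .w and .h fields (as suggested by preset.image_size) and that the network consumes float32 BGR images as loaded by OpenCV:

def sample_generator(files, image_size, batch_size):
    # Yield (batch_of_images, list_of_file_indices) for inference.
    for offset in range(0, len(files), batch_size):
        images, idxs = [], []
        for i, image_file in enumerate(files[offset:offset + batch_size]):
            image = cv2.resize(cv2.imread(image_file),
                               (image_size.w, image_size.h))
            images.append(image.astype(np.float32))
            idxs.append(offset + i)
        yield np.array(images), idxs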