def main():
    """Train the SSD detector.

    Parses the command line, optionally resumes from the newest checkpoint
    found in the project directory, builds the SSD/VGG-16 network, then runs
    the train/validate loop: pushes losses, per-class AP/mAP and sample-image
    summaries to TensorBoard and saves periodic checkpoints.

    Returns:
        0 on success, 1 on any setup failure.
    """
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name', default='test', help='project name')
    parser.add_argument('--data-dir', default='pascal-voc', help='data directory')
    parser.add_argument('--vgg-dir', default='vgg_graph', help='directory for the VGG-16 model')
    parser.add_argument('--epochs', type=int, default=200, help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    parser.add_argument('--tensorboard-dir', default="tb", help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval', type=int, default=5, help='checkpoint interval')
    parser.add_argument('--lr-values', type=str, default='0.00075;0.0001;0.00001', help='learning rate values')
    parser.add_argument('--lr-boundaries', type=str, default='320000;400000', help='learning rate chage boundaries (in batches)')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for the optimizer')
    parser.add_argument('--weight-decay', type=float, default=0.0005, help='L2 normalization factor')
    parser.add_argument('--continue-training', type=str2bool, default='False', help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers', type=int, default=mp.cpu_count(), help='number of parallel generators')
    args = parser.parse_args()

    print('[i] Project name: ', args.name)
    print('[i] Data directory: ', args.data_dir)
    print('[i] VGG directory: ', args.vgg_dir)
    print('[i] # epochs: ', args.epochs)
    print('[i] Batch size: ', args.batch_size)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval: ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum: ', args.momentum)
    print('[i] Weight decay: ', args.weight_decay)
    print('[i] Continue: ', args.continue_training)
    print('[i] Number of workers: ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        # Checkpoints are named 'e<EPOCH>.ckpt'; pick the highest epoch number.
        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                # e.g. 'final.ckpt' — not an epoch checkpoint, skip it
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'
        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            # Fix: guard with an existence check (consistent with the combined
            # training script) — previously a re-run with an existing project
            # directory always failed here on FileExistsError.
            if not os.path.exists(args.name):
                print('[i] Creating directory {}...'.format(args.name))
                os.makedirs(args.name)
        except (IOError) as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch: ', start_epoch+1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples: ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes: ', td.num_classes)
        print('[i] Image size: ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
    with tf.Session() as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train/args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid/args.batch_size))

        # A piecewise-constant learning-rate schedule is only built for a fresh
        # run; on resume the optimizer (and its global step) come from the
        # metagraph instead.
        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                # Fix: use return 1 like every other failure path instead of
                # sys.exit(1); equivalent exit status when run as a script.
                print('[!] Learning rate values must be floats')
                return 1
            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                return 1
            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir, sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)

        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epoch
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []

            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e+1, args.epochs)
            for x, y, gt_boxes in tqdm(generator, total=n_train_batches,
                                       desc=description, unit='batches'):
                # Keep a copy of up to three inputs for the image summary;
                # x is overwritten by the next batch.
                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {net.image_input: x, net.labels: y}
                result, loss_batch, _ = sess.run([net.result, net.losses,
                                                  net.optimizer],
                                                 feed_dict=feed)

                if math.isnan(loss_batch['confidence']):
                    print('[!] Confidence loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

                # Skip the (expensive) AP bookkeeping on the very first epoch.
                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size, args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e+1, args.epochs)

            for x, y, gt_boxes in tqdm(generator, total=n_valid_batches,
                                       desc=description, unit='batches'):
                feed = {net.image_input: x, net.labels: y}
                result, loss_batch = sess.run([net.result, net.losses],
                                              feed_dict=feed)

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0: continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))

            #-------------------------------------------------------------------
            # Write summaries
            #-------------------------------------------------------------------
            training_loss.push(e+1)
            validation_loss.push(e+1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e+1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e+1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e+1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e+1, training_imgs_samples)
            validation_imgs.push(e+1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checktpoint
            #-------------------------------------------------------------------
            if (e+1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e+1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
def main():
    """Train the combined SSD detector + FCN8 segmentation network.

    Same overall flow as the plain SSD trainer (argument parsing, optional
    checkpoint resume, TensorBoard summaries, periodic checkpoints), plus:
    the training feed includes a segmentation ground truth
    (``net.label_seg_gt``), the validation pass evaluates the predicted
    segmentation against the ground truth, and per-epoch scores/images are
    dumped under the hard-coded ``Combined-fcn8-run1`` directory.

    Returns 0 on success, 1 on any setup failure.
    """
    #---------------------------------------------------------------------------
    # Parse the commandline
    #---------------------------------------------------------------------------
    parser = argparse.ArgumentParser(description='Train the SSD')
    parser.add_argument('--name', default='test-combined-run1-6Dec', help='project name')
    parser.add_argument('--data-dir', default='VOC', help='data directory')
    parser.add_argument('--vgg-dir', default='vgg_graph', help='directory for the VGG-16 model')
    parser.add_argument('--epochs', type=int, default=300, help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8, help='batch size')
    # NOTE(review): this flag uses an underscore ('--batch_size_val') while
    # the others use dashes; kept as-is to preserve the CLI.
    parser.add_argument('--batch_size_val', type=int, default=1, help='batch size val')
    parser.add_argument('--tensorboard-dir', default="tb-combined-run1-6Dec", help='name of the tensorboard data directory')
    parser.add_argument('--checkpoint-interval', type=int, default=10, help='checkpoint interval')
    # NOTE(review): the default contains a space after the first ';' — float()
    # tolerates surrounding whitespace, so parsing still works.
    parser.add_argument('--lr-values', type=str, default='0.00001; 0.000001;0.00001', help='learning rate values')
    parser.add_argument('--lr-boundaries', type=str, default='18300;36600', help='learning rate chage boundaries (in batches)')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for the optimizer')
    parser.add_argument('--weight-decay', type=float, default=1e-6, help='L2 normalization factor')
    parser.add_argument('--continue-training', type=str2bool, default='False', help='continue training from the latest checkpoint')
    parser.add_argument('--num-workers', type=int, default=6, help='number of parallel generators')
    args = parser.parse_args()

    print('[i] Project name: ', args.name)
    print('[i] Data directory: ', args.data_dir)
    print('[i] VGG directory: ', args.vgg_dir)
    print('[i] # epochs: ', args.epochs)
    print('[i] Batch size: ', args.batch_size)
    print('[i] Batch size val: ', args.batch_size_val)
    print('[i] Tensorboard directory:', args.tensorboard_dir)
    print('[i] Checkpoint interval: ', args.checkpoint_interval)
    print('[i] Learning rate values: ', args.lr_values)
    print('[i] Learning rate boundaries: ', args.lr_boundaries)
    print('[i] Momentum: ', args.momentum)
    print('[i] Weight decay: ', args.weight_decay)
    print('[i] Continue: ', args.continue_training)
    print('[i] Number of workers: ', args.num_workers)

    #---------------------------------------------------------------------------
    # Find an existing checkpoint
    #---------------------------------------------------------------------------
    start_epoch = 0
    if args.continue_training:
        state = tf.train.get_checkpoint_state(args.name)
        if state is None:
            print('[!] No network state found in ' + args.name)
            return 1

        ckpt_paths = state.all_model_checkpoint_paths
        if not ckpt_paths:
            print('[!] No network state found in ' + args.name)
            return 1

        # Checkpoints are named 'e<EPOCH>.ckpt'; pick the highest epoch.
        last_epoch = None
        checkpoint_file = None
        for ckpt in ckpt_paths:
            ckpt_num = os.path.basename(ckpt).split('.')[0][1:]
            try:
                ckpt_num = int(ckpt_num)
            except ValueError:
                # e.g. 'final.ckpt' — not an epoch checkpoint, skip
                continue
            if last_epoch is None or last_epoch < ckpt_num:
                last_epoch = ckpt_num
                checkpoint_file = ckpt

        if checkpoint_file is None:
            print('[!] No checkpoints found, cannot continue!')
            return 1

        metagraph_file = checkpoint_file + '.meta'
        if not os.path.exists(metagraph_file):
            print('[!] Cannot find metagraph', metagraph_file)
            return 1
        start_epoch = last_epoch

    #---------------------------------------------------------------------------
    # Create a project directory
    #---------------------------------------------------------------------------
    else:
        try:
            if not os.path.exists(args.name):
                print('[i] Creating directory {}...'.format(args.name))
                os.makedirs(args.name)
        except (IOError) as e:
            print('[!]', str(e))
            return 1

    print('[i] Starting at epoch:', start_epoch + 1)

    #---------------------------------------------------------------------------
    # Configure the training data
    #---------------------------------------------------------------------------
    print('[i] Configuring the training data...')
    try:
        td = TrainingData(args.data_dir)
        print('[i] # training samples: ', td.num_train)
        print('[i] # validation samples: ', td.num_valid)
        print('[i] # classes: ', td.num_classes)
        print('[i] Image size: ', td.preset.image_size)
    except (AttributeError, RuntimeError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    #---------------------------------------------------------------------------
    # Create the network
    #---------------------------------------------------------------------------
    # Cap GPU memory so other processes can share the device.
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.8
    with tf.Session(config=config) as sess:
        print('[i] Creating the model...')
        n_train_batches = int(math.ceil(td.num_train / args.batch_size))
        n_valid_batches = int(math.ceil(td.num_valid / args.batch_size_val))

        # The LR schedule is only built for a fresh run; on resume the
        # optimizer state comes from the metagraph.
        global_step = None
        if start_epoch == 0:
            lr_values = args.lr_values.split(';')
            try:
                lr_values = [float(x) for x in lr_values]
            except ValueError:
                print('[!] Learning rate values must be floats')
                sys.exit(1)
            lr_boundaries = args.lr_boundaries.split(';')
            try:
                lr_boundaries = [int(x) for x in lr_boundaries]
            except ValueError:
                print('[!] Learning rate boundaries must be ints')
                sys.exit(1)
            ret = compute_lr(lr_values, lr_boundaries)
            learning_rate, global_step = ret

        net = SSDVGG(sess, td.preset)
        if start_epoch != 0:
            net.build_from_metagraph(metagraph_file, checkpoint_file)
            net.build_optimizer_from_metagraph()
        else:
            net.build_from_vgg(args.vgg_dir, td.num_classes)
            net.build_optimizer(learning_rate=learning_rate,
                                global_step=global_step,
                                weight_decay=args.weight_decay,
                                momentum=args.momentum)
        initialize_uninitialized_variables(sess)

        #-----------------------------------------------------------------------
        # Create various helpers
        #-----------------------------------------------------------------------
        summary_writer = tf.summary.FileWriter(args.tensorboard_dir, sess.graph)
        saver = tf.train.Saver(max_to_keep=20)

        # Segmentation metric accumulators.
        # NOTE(review): these are created once for the whole run and never
        # cleared, so the per-epoch averages printed below are cumulative over
        # all epochs so far, not per-epoch — confirm whether that is intended.
        scores_list = []
        class_scores_list = []
        precision_list = []
        recall_list = []
        f1_list = []
        iou_list = []
        # NOTE(review): the *_training lists below are never appended to in
        # this function.
        scores_list_training = []
        class_scores_list_training = []
        precision_list_training = []
        recall_list_training = []
        f1_list_training = []
        iou_list_training = []
        avg_loss_per_epoch = []
        avg_scores_per_epoch = []
        avg_iou_per_epoch = []
        avg_iou_per_epoch_training = []

        anchors = get_anchors_for_preset(td.preset)
        training_ap_calc = APCalculator()
        validation_ap_calc = APCalculator()

        #-----------------------------------------------------------------------
        # Summaries
        #-----------------------------------------------------------------------
        restore = start_epoch != 0

        training_ap = PrecisionSummary(sess, summary_writer, 'training',
                                       td.lname2id.keys(), restore)
        validation_ap = PrecisionSummary(sess, summary_writer, 'validation',
                                         td.lname2id.keys(), restore)

        training_imgs = ImageSummary(sess, summary_writer, 'training',
                                     td.label_colors, restore)
        validation_imgs = ImageSummary(sess, summary_writer, 'validation',
                                       td.label_colors, restore)

        training_loss = LossSummary(sess, summary_writer, 'training',
                                    td.num_train, restore)
        validation_loss = LossSummary(sess, summary_writer, 'validation',
                                      td.num_valid, restore)

        #-----------------------------------------------------------------------
        # Get the initial snapshot of the network
        #-----------------------------------------------------------------------
        net_summary_ops = net.build_summaries(restore)
        if start_epoch == 0:
            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, 0)
        summary_writer.flush()

        #-----------------------------------------------------------------------
        # Cycle through the epoch
        #-----------------------------------------------------------------------
        print('[i] Training...')
        for e in range(start_epoch, args.epochs):
            training_imgs_samples = []
            validation_imgs_samples = []

            #-------------------------------------------------------------------
            # Train
            #-------------------------------------------------------------------
            generator = td.train_generator(args.batch_size, args.num_workers)
            description = '[i] Train {:>2}/{}'.format(e + 1, args.epochs)
            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator, total=n_train_batches, desc=description,
                    unit='batches'):
                # Keep a copy of up to three inputs for the image summary.
                if len(training_imgs_samples) < 3:
                    saved_images = np.copy(x[:3])

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                output_seg, result, loss_batch, _ = sess.run(
                    [net.logits_seg, net.result, net.losses, net.optimizer],
                    feed_dict=feed)

                # Colour-coded segmentation of the first image in the batch.
                # NOTE(review): output_seg_output is not used in the training
                # loop (only in validation below).
                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                if math.isnan(loss_batch['total']):
                    print('[!] total loss is NaN.')

                training_loss.add(loss_batch, x.shape[0])

                # Skip the AP bookkeeping on the very first epoch.
                if e == 0:
                    continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    training_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(training_imgs_samples) < 3:
                        training_imgs_samples.append((saved_images[i], boxes))

            #-------------------------------------------------------------------
            # Validate
            #-------------------------------------------------------------------
            generator = td.valid_generator(args.batch_size_val, args.num_workers)
            description = '[i] Valid {:>2}/{}'.format(e + 1, args.epochs)

            # Per-epoch dump directory + CSV of segmentation scores.
            # NOTE(review): 'Combined-fcn8-run1' is hard-coded rather than
            # derived from args.name.
            if not os.path.isdir("%s/%04d" % ("Combined-fcn8-run1", e)):
                os.makedirs("%s/%04d" % ("Combined-fcn8-run1", e))
            counter_val = 0
            target = open("%s/%04d/val_scores.csv" % ("Combined-fcn8-run1", e), 'w')
            # NOTE(review): the header embeds counter_val (always 0 here) as
            # its last column name — looks accidental, preserved as-is.
            target.write(
                "val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n"
                % str(counter_val))

            for x, y, gt_boxes, img_seg_gt, imgseg_gt_to_compare in tqdm(
                    generator, total=n_valid_batches, desc=description,
                    unit='batches'):
                # Ground truth as a 2-D class-index map for the evaluator.
                gt_rev_onehot = reverse_one_hot(
                    one_hot_encode(np.squeeze(imgseg_gt_to_compare)))

                feed = {
                    net.image_input: x,
                    net.labels: y,
                    net.label_seg_gt: img_seg_gt
                }
                result, output_seg, loss_batch = sess.run(
                    [net.result, net.logits_seg, net.losses], feed_dict=feed)

                # Prediction as a class-index map + colour-coded image.
                # NOTE(review): only the first image of the batch is evaluated;
                # with --batch_size_val 1 (the default) that covers everything.
                output_image = np.array(output_seg[0, :, :, :])
                output_seg_rev = reverse_one_hot(output_image)
                output_seg_output = colour_code_segmentation(output_seg_rev)

                accuracy, class_accuracies, prec, rec, f1, iou = evaluate_segmentation(
                    pred=output_seg_rev, label=gt_rev_onehot, num_classes=21)

                # One CSV row per validation sample, indexed by counter_val.
                filename = str(counter_val)
                target.write("%s, %f, %f, %f, %f, %f" %
                             (filename, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)

                # Dump ground-truth and predicted segmentations as PNGs.
                original_gt = colour_code_segmentation(gt_rev_onehot)
                cv2.imwrite(
                    "%s/%04d/%s_gt.png" % ("Combined-fcn8-run1", e, filename),
                    original_gt)
                cv2.imwrite(
                    "%s/%04d/%s_pred.png" % ("Combined-fcn8-run1", e, filename),
                    output_seg_output)
                counter_val += 1

                validation_loss.add(loss_batch, x.shape[0])

                if e == 0:
                    continue

                for i in range(result.shape[0]):
                    boxes = decode_boxes(result[i], anchors, 0.5, td.lid2name)
                    boxes = suppress_overlaps(boxes)
                    validation_ap_calc.add_detections(gt_boxes[i], boxes)

                    if len(validation_imgs_samples) < 3:
                        validation_imgs_samples.append((np.copy(x[i]), boxes))
            target.close()

            #-------------------------------------------------------------------
            #Write summaries
            #-------------------------------------------------------------------
            avg_score = np.mean(scores_list)
            avg_scores_per_epoch.append(avg_score)
            avg_precision = np.mean(precision_list)
            avg_recall = np.mean(recall_list)
            avg_f1 = np.mean(f1_list)
            avg_iou = np.mean(iou_list)
            avg_iou_per_epoch.append(avg_iou)

            print("\nAverage validation accuracy for epoch # %04d = %f" %
                  (e, avg_score))
            print("Validation precision = ", avg_precision)
            print("Validation recall = ", avg_recall)
            print("Validation F1 score = ", avg_f1)
            print("Validation IoU score = ", avg_iou)

            training_loss.push(e + 1)
            validation_loss.push(e + 1)

            net_summary = sess.run(net_summary_ops)
            summary_writer.add_summary(net_summary, e + 1)

            APs = training_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            training_ap.push(e + 1, mAP, APs)

            APs = validation_ap_calc.compute_aps()
            mAP = APs2mAP(APs)
            validation_ap.push(e + 1, mAP, APs)

            training_ap_calc.clear()
            validation_ap_calc.clear()

            training_imgs.push(e + 1, training_imgs_samples)
            validation_imgs.push(e + 1, validation_imgs_samples)

            summary_writer.flush()

            #-------------------------------------------------------------------
            # Save a checktpoint
            #-------------------------------------------------------------------
            if (e + 1) % args.checkpoint_interval == 0:
                checkpoint = '{}/e{}.ckpt'.format(args.name, e + 1)
                saver.save(sess, checkpoint)
                print('[i] Checkpoint saved:', checkpoint)

        checkpoint = '{}/final.ckpt'.format(args.name)
        saver.save(sess, checkpoint)
        print('[i] Checkpoint saved:', checkpoint)

    return 0
def main():
    """Run SSD inference over the train or valid sample set.

    Restores the network from a checkpoint, decodes detections for every
    existing sample image and, depending on the flags, annotates images,
    dumps raw predictions, accumulates mAP stats, and/or writes Pascal VOC
    summary files.

    Returns:
        0 on success, 1 on any setup failure.
    """
    # Parse commandline
    parser = argparse.ArgumentParser(description='SSD inference')
    parser.add_argument("files", nargs="*")
    parser.add_argument('--checkpoint-dir', default='pascal-voc/checkpoints', help='project name')
    parser.add_argument('--checkpoint', type=int, default=-1, help='checkpoint to restore; -1 is the most recent')
    parser.add_argument('--data-source', default="pascal-voc", help='Use test files from the data source')
    parser.add_argument('--data-dir', default='pascal-voc', help='Use test files from the data source')
    parser.add_argument('--training-data', default='pascal-voc/training-data.pkl', help='Information about parameters used for training')
    parser.add_argument('--output-dir', default='pascal-voc/annotated/train', help='directory for the resulting images')
    parser.add_argument('--annotate', type=str2bool, default='False', help="Annotate the data samples")
    parser.add_argument('--dump-predictions', type=str2bool, default='False', help="Dump raw predictions")
    parser.add_argument('--summary', type=str2bool, default='True', help='dump the detections in Pascal VOC format')
    parser.add_argument('--compute-stats', type=str2bool, default='True', help="Compute the mAP stats")
    parser.add_argument('--sample', default='train', choices=['train', 'valid'])
    parser.add_argument('--batch-size', type=int, default=32, help='batch size')
    parser.add_argument('--threshold', type=float, default=0.5, help='confidence threshold')
    args = parser.parse_args()

    # Print parameters
    print('[i] Checkpoint directory: ', args.checkpoint_dir)
    print('[i] Data source: ', args.data_source)
    print('[i] Data directory: ', args.data_dir)
    print('[i] Training data: ', args.training_data)
    print('[i] Output directory: ', args.output_dir)
    print('[i] Annotate: ', args.annotate)
    print('[i] Dump predictions: ', args.dump_predictions)
    print('[i] Summary: ', args.summary)
    print('[i] Compute state: ', args.compute_stats)
    print('[i] Sample: ', args.sample)
    print('[i] Batch size: ', args.batch_size)
    print('[i] Threshold: ', args.threshold)

    # Check if we can get the checkpoint
    state = tf.train.get_checkpoint_state(args.checkpoint_dir)
    if state is None:
        print('[!] No network state found in ' + args.checkpoint_dir)
        return 1

    try:
        # args.checkpoint indexes the saved checkpoints; -1 picks the newest.
        checkpoint_file = state.all_model_checkpoint_paths[args.checkpoint]
    except IndexError:
        # Fix: args has no 'checkpoint_file' attribute — referencing it here
        # raised AttributeError inside the handler; report the bad index.
        print('[!] Cannot find checkpoint ' + str(args.checkpoint))
        return 1

    metagraph_file = checkpoint_file + '.meta'
    if not os.path.exists(metagraph_file):
        print('[!] Cannot find metagraph ' + metagraph_file)
        return 1

    # Load the training data parameters
    try:
        with open(args.training_data, 'rb') as f:
            data = pickle.load(f)
            preset = data['preset']
            colors = data['colors']
            lid2name = data['lid2name']
            image_size = preset.image_size
            anchors = get_anchors_for_preset(preset)
    except (FileNotFoundError, IOError, KeyError) as e:
        print('[!] Unable to load training data:', str(e))
        return 1

    # Load the samples according to data source and sample type
    try:
        if args.sample == 'train':
            with open(args.data_dir + '/train-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        else:
            with open(args.data_dir + '/valid-samples.pkl', 'rb') as f:
                samples = pickle.load(f)
        num_samples = len(samples)
        print('[i] # samples: ', num_samples)
    except (ImportError, AttributeError, RuntimeError) as e:
        print('[!] Unable to load data source:', str(e))
        return 1

    # Create a list of files to analyse and make sure that the output
    # directory exists
    files = []
    for sample in samples:
        files.append(sample.filename)
    files = list(filter(lambda x: os.path.exists(x), files))
    if files:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        # Fix: annotated images go to <output-dir>/images/ which was never
        # created, so cv2.imwrite silently failed for every frame.
        if args.annotate and not os.path.exists(args.output_dir + '/images'):
            os.makedirs(args.output_dir + '/images')

    # Print model and dataset stats
    print('[i] Network checkpoint:', checkpoint_file)
    print('[i] Metagraph file: ', metagraph_file)
    print('[i] Image size: ', image_size)
    print('[i] Number of files: ', len(files))

    # Create the network
    if args.compute_stats:
        ap_calc = APCalculator()
    if args.summary:
        summary = PascalSummary()

    with tf.Session() as sess:
        print('[i] Creating the model...')
        net = SSDVGG(sess, preset)
        net.build_from_metagraph(metagraph_file, checkpoint_file)

        # Process the images
        generator = sample_generator(files, image_size, args.batch_size)
        n_sample_batches = int(math.ceil(len(files) / args.batch_size))
        description = '[i] Processing samples'

        for x, idxs in tqdm(generator, total=n_sample_batches,
                            desc=description, unit='batches'):
            feed = {net.image_input: x, net.keep_prob: 1}
            enc_boxes = sess.run(net.result, feed_dict=feed)

            # Process the predictions
            for i in range(enc_boxes.shape[0]):
                boxes = decode_boxes(enc_boxes[i], anchors, args.threshold,
                                     lid2name, None)
                # Cap at 200 detections per image after NMS.
                boxes = suppress_overlaps(boxes)[:200]
                filename = files[idxs[i]]
                basename = os.path.basename(filename)

                # Annotate samples
                if args.annotate:
                    img = cv2.imread(filename)
                    for box in boxes:
                        draw_box(img, box[1], colors[box[1].label])
                    fn = args.output_dir + '/images/' + basename
                    cv2.imwrite(fn, img)

                # Dump the predictions
                if args.dump_predictions:
                    raw_fn = args.output_dir + '/' + basename + '.npy'
                    np.save(raw_fn, enc_boxes[i])

                # Add predictions to the stats calculator and to the summary
                if args.compute_stats:
                    ap_calc.add_detections(samples[idxs[i]].boxes, boxes)
                if args.summary:
                    summary.add_detections(filename, boxes)

    # Compute and print the stats
    if args.compute_stats:
        aps = ap_calc.compute_aps()
        for k, v in aps.items():
            print('[i] AP [{0}]: {1:.3f}'.format(k, v))
        print('[i] mAP: {0:.3f}'.format(APs2mAP(aps)))

    # Write the summary files
    if args.summary:
        summary.write_summary(args.output_dir + "/summaries")

    print('[i] All done.')
    return 0