def train_network(gpu_config):
    """Train the capsule network from scratch, then evaluate it on the test set.

    For each of ``config.n_epochs`` epochs the network is trained once over a
    fresh ``TrainDataGen``. Validation is run only once the training accuracy
    has reached ``config.acc_for_eval`` (and on the last epoch), and the
    network is saved to ``config.save_file_name`` whenever the summed
    validation loss improves. When training finishes, ``iou()`` is called to
    report the final test metrics.

    Args:
        gpu_config: session configuration object, forwarded unchanged as the
            ``config`` argument of ``tf.Session``.
    """
    capsnet = Caps3d()

    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()

        get_num_params()
        config.clear_output()

        # n_eps_after_acc stays at -1 until the accuracy threshold is first
        # reached; best_loss starts at +inf so the first finite validation
        # loss is always treated as an improvement.
        n_eps_after_acc, best_loss = -1, float('inf')
        print('Training on UCF101')
        for ep in range(1, config.n_epochs + 1):
            print(20 * '*', 'epoch', ep, 20 * '*')

            # trains network for one epoch
            data_gen = TrainDataGen(config.wait_for_data, frame_skip=config.frame_skip)
            margin_loss, seg_loss, acc = capsnet.train(sess, data_gen)

            config.write_output('CL: %.4f. SL: %.4f. Acc: %.4f\n' % (margin_loss, seg_loss, acc))

            # increments the margin every n_eps_for_m epochs, capped at 0.9
            if ep % config.n_eps_for_m == 0:
                capsnet.cur_m += config.m_delta
                capsnet.cur_m = min(capsnet.cur_m, 0.9)

            # only validates after a certain number of epochs and when the training accuracy is greater than a
            # threshold; this is mainly used to save time, since validation takes about 10 minutes
            if (acc >= config.acc_for_eval or n_eps_after_acc >= 0) and ep >= config.n_eps_until_eval:
                n_eps_after_acc += 1

            # validates the network (always on the final epoch)
            if (acc >= config.acc_for_eval and n_eps_after_acc % config.n_eps_for_eval == 0) or ep == config.n_epochs:
                data_gen = TestDataGen(config.wait_for_data, frame_skip=1)
                margin_loss, seg_loss, accuracy, _ = capsnet.eval(sess, data_gen, validation=True)

                config.write_output('Validation\tCL: %.4f. SL: %.4f. Acc: %.4f.\n' %
                                    (margin_loss, seg_loss, accuracy))

                # saves the network when the validation loss is minimized
                t_loss = margin_loss + seg_loss
                if t_loss < best_loss:
                    best_loss = t_loss
                    try:
                        capsnet.save(sess, config.save_file_name)
                        config.write_output('Saved Network\n')
                    except Exception as err:
                        # Saving is best-effort: keep training, but report why
                        # it failed. KeyboardInterrupt/SystemExit are no longer
                        # swallowed here.
                        print('Failed to save network!!!', err)

    # calculate final test accuracy, f-mAP, and v-mAP
    # (iou() builds its own network and session)
    iou()
def _build_eval_clips(video, bbox, label, f_skip):
    """Cut one video into 8-frame clips, keeping only clips with annotations.

    Frames inside a clip are ``f_skip`` apart; for every window of
    ``8 * f_skip`` frames, ``f_skip`` interleaved clips are produced so each
    frame is covered once. Indices past the end of the video are padded with
    all-zero frames and all-zero boxes.

    Args:
        video: array indexed as ``video[frame, :, :, :]``
            (presumably 112x112x3 frames, to match the padding - verify
            against the data generator).
        bbox: array indexed like ``video`` holding the ground-truth masks.
        label: class label attached to every clip.
        f_skip: temporal stride between the frames of a clip.

    Returns:
        List of ``(clip_frames, clip_boxes, label)`` tuples; clips whose boxes
        sum to zero are omitted.
    """
    n_frames = video.shape[0]
    clips = []
    for start in range(0, n_frames, 8 * f_skip):
        for offset in range(f_skip):
            b_vid, b_bbox = [], []
            for k in range(8):
                ind = start + offset + k * f_skip
                if ind >= n_frames:
                    b_vid.append(np.zeros((1, 112, 112, 3), dtype=np.float32))
                    b_bbox.append(np.zeros((1, 112, 112, 1), dtype=np.float32))
                else:
                    b_vid.append(video[ind:ind + 1, :, :, :])
                    b_bbox.append(bbox[ind:ind + 1, :, :, :])

            clip_bbox = np.concatenate(b_bbox, axis=0)
            if np.sum(clip_bbox) == 0:
                # nothing to score in this clip
                continue
            clips.append((np.concatenate(b_vid, axis=0), clip_bbox, label))
    return clips


def _add_threshold_hits(hits, label, i_over_u, iou_threshs):
    """Increment ``hits[label, k]`` for every threshold ``k`` that ``i_over_u`` reaches."""
    for k in range(iou_threshs.shape[0]):
        if i_over_u >= iou_threshs[k]:
            hits[label, k] += 1


def _per_class_ap(hits, counts):
    """Turn per-class hit counts into per-class AP and the mean AP.

    Classes whose count is zero have no data: their AP row is NaN and they are
    left out of the mean, so a single empty class no longer turns every mAP
    value into NaN.

    Args:
        hits: ``(n_classes, n_thresholds)`` array of hit counts.
        counts: ``(n_classes, 1)`` array of sample counts.

    Returns:
        ``(ap, mean_ap)`` with shapes ``(n_classes, n_thresholds)`` and
        ``(n_thresholds,)``. ``mean_ap`` is all zeros when no class has data.
    """
    seen = counts[:, 0] > 0
    ap = np.full(hits.shape, np.nan)
    ap[seen] = hits[seen] / counts[seen]
    if seen.any():
        mean_ap = np.mean(ap[seen], axis=0)
    else:
        mean_ap = np.zeros(hits.shape[1])
    return ap, mean_ap


def _report_map(title, iou_threshs, mean_ap, class_ap):
    """Print and log one mAP table followed by the per-class AP vector."""
    print(title)
    config.write_output(title + '\n')
    for thresh, value in zip(iou_threshs, mean_ap):
        print(thresh, value)
        config.write_output('%.4f\t%.4f\n' % (thresh, value))
    config.write_output(str(class_ap) + '\n')
    print(class_ap)


def iou():
    """
    Calculates the accuracy, f-mAP, and v-mAP over the test set

    Loads the network saved at ``config.save_file_name``, runs it on every
    video from ``TestDataGen`` and accumulates, per class, how many annotated
    frames (f-mAP) and videos (v-mAP) reach each IoU threshold
    0.00, 0.05, ..., 0.95. Results are printed and written with
    ``config.write_output``; nothing is returned.
    """
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

    capsnet = Caps3d()
    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        capsnet.load(sess, config.save_file_name)

        data_gen = TestDataGen(config.wait_for_data)

        n_threshs = 20
        iou_threshs = np.arange(0, n_threshs, dtype=np.float32) / n_threshs
        half_idx = n_threshs // 2  # column of the 0.5 threshold, reported per class

        n_correct = 0
        n_vids = np.zeros((config.n_classes, 1))
        n_tot_frames = np.zeros((config.n_classes, 1))
        frame_ious = np.zeros((config.n_classes, n_threshs))
        video_ious = np.zeros((config.n_classes, n_threshs))

        while data_gen.has_data():
            video, bbox, label = data_gen.get_next_video()

            clips = _build_eval_clips(video, bbox, label, config.frame_skip)
            if not clips:
                print('Video has no bounding boxes')
                continue

            # ground truth for every frame of every clip, in clip order
            gt_segmentations = np.stack([clip_bb for _, clip_bb, _ in clips])
            gt_segmentations = gt_segmentations.reshape((-1, 112, 112, 1))  # Shape N_FRAMES, 112, 112, 1

            segmentations, predictions = [], []
            for start in range(0, len(clips), config.batch_size):
                batch = clips[start:start + config.batch_size]
                x_batch = [clip_x for clip_x, _, _ in batch]
                y_batch = [clip_y for _, _, clip_y in batch]
                segmentation, pred = sess.run(
                    [capsnet.segment_layer_sig, capsnet.digit_preds],
                    feed_dict={capsnet.x_input: x_batch, capsnet.y_input: y_batch,
                               capsnet.m: 0.9, capsnet.is_train: False})
                segmentations.append(segmentation)
                predictions.append(pred)

            # video-level class = argmax of the mean clip prediction
            predictions = np.concatenate(predictions, axis=0)
            predictions = predictions.reshape((-1, config.n_classes))
            fin_pred = np.argmax(np.mean(predictions, axis=0))
            if fin_pred == label:
                n_correct += 1

            pred_segmentations = np.concatenate(segmentations, axis=0)
            pred_segmentations = pred_segmentations.reshape((-1, 112, 112, 1))
            pred_segmentations = (pred_segmentations >= 0.5).astype(np.int32)

            # 2 where prediction and ground truth overlap, non-zero where either is set
            seg_plus_gt = pred_segmentations + gt_segmentations

            vid_inter, vid_union = 0, 0
            # calculates f_map
            for i in range(gt_segmentations.shape[0]):
                if np.sum(gt_segmentations[i]) == 0:
                    # unannotated (or padded) frame: not scored
                    continue

                n_tot_frames[label] += 1

                inter = np.count_nonzero(seg_plus_gt[i] == 2)
                union = np.count_nonzero(seg_plus_gt[i])
                vid_inter += inter
                vid_union += union

                _add_threshold_hits(frame_ious, label, inter / union, iou_threshs)

            # video-level IoU pools intersections and unions over all scored frames
            n_vids[label] += 1
            _add_threshold_hits(video_ious, label, vid_inter / vid_union, iou_threshs)

            if np.sum(n_vids) % 100 == 0:
                print('Finished %d videos' % np.sum(n_vids))

        total_vids = np.sum(n_vids)
        # guard against an empty test set instead of reporting NaN
        accuracy = float(n_correct / total_vids) if total_vids > 0 else 0.0
        print('Accuracy:', accuracy)
        config.write_output('Test Accuracy: %.4f\n' % accuracy)

        fAP, fmAP = _per_class_ap(frame_ious, n_tot_frames)
        vAP, vmAP = _per_class_ap(video_ious, n_vids)

        _report_map('IoU f-mAP:', iou_threshs, fmAP, fAP[:, half_idx])
        _report_map('IoU v-mAP:', iou_threshs, vmAP, vAP[:, half_idx])
def inference(video_name):
    """Segment one video with the saved network and write the results to disk.

    The video is read with ``vread``, resized to 120x160 unless it is already
    112x112, centre-cropped to 112x112 and written to ``cropped.avi``. The
    network's segmentation output is thresholded at 0.5 and the segmented
    pixels are blended with a fixed highlight colour; the highlighted video is
    written to ``segmented_vid.avi``.

    Args:
        video_name: path of the video file passed to ``vread``.
    """
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

    capsnet = Caps3d()
    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        capsnet.load(sess, config.save_file_name)

        video = vread(video_name)

        n_frames = video.shape[0]
        crop_size = (112, 112)

        # assumes a given aspect ratio of (240, 320). If given a cropped video, then no resizing occurs.
        # The video counts as "cropped" only when BOTH spatial dimensions already match crop_size;
        # a video with a single matching dimension is resized like any other.
        if tuple(video.shape[1:3]) != crop_size:
            h, w = 120, 160
            video_res = np.zeros((n_frames, h, w, 3))
            for f in range(n_frames):
                video_res[f] = imresize(video[f], (h, w))
        else:
            h, w = crop_size
            video_res = video

        # crops video to 112x112 around the centre
        h_crop_start = (h - crop_size[0]) // 2
        w_crop_start = (w - crop_size[1]) // 2
        video_cropped = video_res[:,
                                  h_crop_start:h_crop_start + crop_size[0],
                                  w_crop_start:w_crop_start + crop_size[1],
                                  :]

        print('Saving Cropped Video')
        vwrite('cropped.avi', video_cropped)

        video_cropped = video_cropped / 255.

        segmentation_output = np.zeros((n_frames, crop_size[0], crop_size[1], 1))
        f_skip = config.frame_skip
        window = 8 * f_skip  # frames consumed per network call

        for i in range(0, n_frames, window):
            # if frames are skipped (subsampled) during training, they should also be skipped at test time
            # creates a batch of f_skip interleaved clips; frame i+k goes to clip k % f_skip
            x_batch = [[] for _ in range(f_skip)]
            for k in range(window):
                if i + k >= n_frames:
                    # pad past the end of the video with black frames
                    x_batch[k % f_skip].append(np.zeros_like(video_cropped[-1]))
                else:
                    x_batch[k % f_skip].append(video_cropped[i + k])
            x_batch = [np.stack(x, axis=0) for x in x_batch]

            # runs the network to get segmentations (labels are fed as -1)
            seg_out = sess.run(capsnet.segment_layer_sig,
                               feed_dict={capsnet.x_input: x_batch,
                                          capsnet.is_train: False,
                                          capsnet.y_input: np.ones((f_skip,), np.int32) * -1})

            # collects the segmented frames into the correct order
            for k in range(window):
                if i + k >= n_frames:
                    continue
                segmentation_output[i + k] = seg_out[k % f_skip][k // f_skip]

        # Final segmentation output
        segmentation_output = (segmentation_output >= 0.5).astype(np.int32)

        # Highlights the video based on the segmentation
        alpha = 0.5
        color = np.array([0.0, 0.0, 1.0])
        highlighted = video_cropped * (1 - alpha) + alpha * color
        # the single-channel mask broadcasts across the three colour channels
        masked_vid = np.where(segmentation_output == 1, highlighted, video_cropped)

        print('Saving Segmented Video')
        vwrite('segmented_vid.avi', (masked_vid * 255).astype(np.uint8))
def train_network(gpu_config):
    """Train the capsule network, optionally resuming from a saved epoch.

    When ``config.start_at_epoch`` is greater than 1, the checkpoint of the
    previous epoch (``config.save_file_name % (start_at_epoch - 1)``) is loaded
    first; otherwise the output log is cleared. Each epoch is attempted up to
    three times if ``capsnet.train`` reports a negative margin loss or
    accuracy. Checkpoints are written to ``config.save_file_name % ep`` every
    ``config.save_every_n_epochs`` epochs, and also when the accuracy-based
    schedule triggers or the last epoch is reached.

    Args:
        gpu_config: session configuration object, forwarded unchanged as the
            ``config`` argument of ``tf.compat.v1.Session``.

    Raises:
        SystemExit: with status 1 when an epoch fails three times in a row.
    """
    max_nan_tries = 3
    capsnet = Caps3d()

    with tf.compat.v1.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.compat.v1.global_variables_initializer().run()

        def save_checkpoint(epoch):
            """Best-effort save of the current weights; returns True on success."""
            try:
                capsnet.save(sess, config.save_file_name % epoch)
                config.write_output('Saved Network\n')
            except Exception as err:
                # Keep training if saving fails, but report why.
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                print('Failed to save network!!!', err)
                return False
            return True

        get_num_params()
        if config.start_at_epoch <= 1:
            config.clear_output()
        else:
            capsnet.load(sess, config.save_file_name % (config.start_at_epoch - 1))
            print('Loading from epoch %d.' % (config.start_at_epoch - 1))

        # stays at -1 until the training accuracy first reaches the threshold
        n_eps_after_acc = -1
        print('Training on UCF101')
        for ep in range(config.start_at_epoch, config.n_epochs + 1):
            print(20 * '*', 'epoch', ep, 20 * '*')

            nan_tries = 0
            while nan_tries < max_nan_tries:
                # trains network for one epoch
                data_gen = TrainDataGen(config.wait_for_data, frame_skip=config.frame_skip)
                margin_loss, seg_loss, acc = capsnet.train(sess, data_gen)

                if margin_loss < 0 or acc < 0:
                    # Failed epoch: retry with a fresh data generator.
                    # NOTE: the weights are NOT reloaded from a checkpoint
                    # before retrying.
                    nan_tries += 1
                else:
                    config.write_output('CL: %.4f. SL: %.4f. Acc: %.4f\n' % (margin_loss, seg_loss, acc))
                    break

            if nan_tries == max_nan_tries:
                print('Network cannot be trained. Too many NaN issues.')
                # non-zero status so callers/scripts can detect the failure
                raise SystemExit(1)

            saved_this_epoch = False
            if ep % config.save_every_n_epochs == 0:
                saved_this_epoch = save_checkpoint(ep)

            # increments the margin every n_eps_for_m epochs, capped at 0.9
            if ep % config.n_eps_for_m == 0:
                capsnet.cur_m += config.m_delta
                capsnet.cur_m = min(capsnet.cur_m, 0.9)

            # only counts epochs after a certain number of epochs and when the training accuracy is greater than a
            # threshold
            if (acc >= config.acc_for_eval or n_eps_after_acc >= 0) and ep >= config.n_eps_until_eval:
                n_eps_after_acc += 1

            # NOTE: the validation pass is disabled in this version, so on the
            # epochs where it would have run (and on the final epoch) the
            # network is saved unconditionally instead of only when the
            # validation loss improves.
            if (acc >= config.acc_for_eval and n_eps_after_acc % config.n_eps_for_eval == 0) or ep == config.n_epochs:
                # skip the write if this epoch's checkpoint was already saved above
                if not saved_this_epoch:
                    save_checkpoint(ep)