Example #1
0
def train_network(gpu_config):
    """Train Caps3d on UCF101, validate periodically, then evaluate on the test set.

    Each epoch trains on a fresh ``TrainDataGen`` and logs the classification
    loss (CL), segmentation loss (SL) and accuracy via ``config.write_output``.
    Every ``config.n_eps_for_m`` epochs the margin ``capsnet.cur_m`` grows by
    ``config.m_delta`` (capped at 0.9). Validation runs on ``TestDataGen`` and
    the weights are saved to ``config.save_file_name`` whenever the validation
    loss (CL + SL) improves. After training, ``iou()`` reports the final test
    accuracy, f-mAP and v-mAP.

    Args:
        gpu_config: session configuration passed to ``tf.Session``
            (presumably a ``tf.ConfigProto``).
    """
    capsnet = Caps3d()

    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()

        get_num_params()
        config.clear_output()

        # n_eps_after_acc stays -1 until validation is unlocked, then counts epochs
        n_eps_after_acc, best_loss = -1, float('inf')
        print('Training on UCF101')
        for ep in range(1, config.n_epochs + 1):
            print(20 * '*', 'epoch', ep, 20 * '*')

            # trains network for one epoch
            data_gen = TrainDataGen(config.wait_for_data,
                                    frame_skip=config.frame_skip)
            margin_loss, seg_loss, acc = capsnet.train(sess, data_gen)
            config.write_output('CL: %.4f. SL: %.4f. Acc: %.4f\n' %
                                (margin_loss, seg_loss, acc))

            # increments the margin
            if ep % config.n_eps_for_m == 0:
                capsnet.cur_m += config.m_delta
                capsnet.cur_m = min(capsnet.cur_m, 0.9)

            # only validates after a certain number of epochs and when the training accuracy is greater than a threshold
            # this is mainly used to save time, since validation takes about 10 minutes
            if (acc >= config.acc_for_eval
                    or n_eps_after_acc >= 0) and ep >= config.n_eps_until_eval:
                n_eps_after_acc += 1

            # n_eps_after_acc >= 0 is required: -1 % 1 == 0, so with
            # n_eps_for_eval == 1 validation would otherwise run before
            # n_eps_until_eval is reached
            eval_due = (n_eps_after_acc >= 0
                        and n_eps_after_acc % config.n_eps_for_eval == 0)

            # validates the network
            if (acc >= config.acc_for_eval and eval_due) or ep == config.n_epochs:
                data_gen = TestDataGen(config.wait_for_data, frame_skip=1)
                margin_loss, seg_loss, accuracy, _ = capsnet.eval(
                    sess, data_gen, validation=True)

                config.write_output(
                    'Validation\tCL: %.4f. SL: %.4f. Acc: %.4f.\n' %
                    (margin_loss, seg_loss, accuracy))

                # saves the network when validation loss is minimized
                t_loss = margin_loss + seg_loss
                if t_loss < best_loss:
                    best_loss = t_loss
                    try:
                        capsnet.save(sess, config.save_file_name)
                    except Exception as err:
                        # best effort: a failed checkpoint must not abort training
                        print('Failed to save network!!!', err)
                    else:
                        config.write_output('Saved Network\n')

    # calculate final test accuracy, f-mAP, and v-mAP; called after the training
    # session is closed so its resources are released before iou() creates
    # its own graph and session
    iou()
def _split_into_clips(video, bbox, label, f_skip, clip_len=8, frame_size=112):
    """Cut a video into clips of ``clip_len`` frames sampled every ``f_skip`` frames.

    Each window of ``clip_len * f_skip`` frames yields ``f_skip`` interleaved
    clips (starting at offsets 0 .. f_skip - 1). Positions past the end of the
    video are filled with zero frames/masks, and clips whose masks sum to zero
    (no annotated pixel) are dropped.

    Args:
        video: 4-D array of frames indexed along axis 0.
        bbox: 4-D ground-truth masks aligned with ``video`` along axis 0.
        label: class label attached to every clip.
        f_skip: temporal stride between consecutive frames of one clip.
        clip_len: number of frames per clip.
        frame_size: height and width of the zero padding frames.

    Returns:
        list of ``(frames, masks, label)`` tuples; frames and masks are
        concatenated along axis 0.
    """
    n_frames = video.shape[0]
    pad_frame = np.zeros((1, frame_size, frame_size, 3), dtype=np.float32)
    pad_mask = np.zeros((1, frame_size, frame_size, 1), dtype=np.float32)

    clips = []
    for start in range(0, n_frames, clip_len * f_skip):
        for offset in range(f_skip):
            frames, masks = [], []
            for k in range(clip_len):
                ind = start + offset + k * f_skip
                if ind >= n_frames:
                    frames.append(pad_frame)
                    masks.append(pad_mask)
                else:
                    frames.append(video[ind:ind + 1, :, :, :])
                    masks.append(bbox[ind:ind + 1, :, :, :])
            clip_masks = np.concatenate(masks, axis=0)
            # clips without any annotated pixel cannot be scored
            if np.sum(clip_masks) != 0:
                clips.append((np.concatenate(frames, axis=0), clip_masks, label))
    return clips


def _segmentation_ious(pred_masks, gt_masks):
    """IoU between predicted and ground-truth masks of the same shape.

    Only frames (axis 0) whose ground-truth mask is non-zero are scored. A
    pixel counts as intersection when ``pred + gt == 2`` and as union when
    ``pred + gt != 0``, so both inputs are expected to be 0/1 masks.

    Returns:
        (frame_ious, video_iou): 1-D array with the IoU of each annotated frame,
        and the IoU of the whole video computed from the summed intersections
        and unions of those frames (0.0 when no frame is annotated).
    """
    n = gt_masks.shape[0]
    overlap = (pred_masks + gt_masks).reshape(n, -1)
    annotated = gt_masks.reshape(n, -1).sum(axis=1) != 0
    inter = np.count_nonzero(overlap[annotated] == 2, axis=1)
    union = np.count_nonzero(overlap[annotated], axis=1)
    total_union = union.sum()
    video_iou = inter.sum() / total_union if total_union else 0.0
    return inter / union, video_iou


def _average_precision(hits, counts):
    """Per-class hit rate and its mean over the classes that have samples.

    Args:
        hits: (n_classes, n_thresholds) number of samples reaching each threshold.
        counts: (n_classes, 1) number of samples per class.

    Returns:
        (per_class, mean): ``hits / counts`` with NaN rows for classes without
        samples, and the column-wise mean over the remaining classes, so a
        missing class no longer turns the whole mean into NaN.
    """
    per_class = np.full(hits.shape, np.nan)
    np.divide(hits, counts, out=per_class, where=counts > 0)
    present = counts[:, 0] > 0
    return per_class, per_class[present].mean(axis=0)


def iou():
    """
    Calculates the accuracy, f-mAP, and v-mAP over the test set.

    Loads the weights from ``config.save_file_name`` and runs every test video
    through the network in clips of 8 frames. Results go to stdout and to
    ``config.write_output``:
      * accuracy: a video is correct when the argmax of its clip predictions,
        averaged over clips, equals its label;
      * f-mAP: per class, the fraction of annotated frames whose IoU reaches
        each threshold in {0.00, 0.05, ..., 0.95}, averaged over classes;
      * v-mAP: the same with one IoU per video, computed from intersections
        and unions summed over its annotated frames.
    Classes without samples are reported as NaN and excluded from the class
    averages. Videos without any bounding box are skipped.
    """
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

    n_threshs = 20
    # thresholds 0.00, 0.05, ..., 0.95 in float64: float32 rounds 0.05 up, so an
    # IoU of exactly 1/20 could fail ``>=`` depending on NumPy's casting rules
    iou_threshs = np.arange(n_threshs) / n_threshs
    mid = n_threshs // 2  # column of the 0.5 IoU threshold

    capsnet = Caps3d()
    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        capsnet.load(sess, config.save_file_name)

        data_gen = TestDataGen(config.wait_for_data)

        n_correct = 0
        # per-class counts are (n_classes, 1) so they broadcast against the
        # (n_classes, n_threshs) hit tables
        n_vids = np.zeros((config.n_classes, 1))
        n_tot_frames = np.zeros((config.n_classes, 1))
        frame_ious = np.zeros((config.n_classes, n_threshs))
        video_ious = np.zeros((config.n_classes, n_threshs))

        while data_gen.has_data():
            video, bbox, label = data_gen.get_next_video()

            clips = _split_into_clips(video, bbox, label, config.frame_skip)
            if not clips:
                print('Video has no bounding boxes')
                continue

            segmentations, predictions = [], []
            for start in range(0, len(clips), config.batch_size):
                batch = clips[start:start + config.batch_size]
                segmentation, pred = sess.run(
                    [capsnet.segment_layer_sig, capsnet.digit_preds],
                    feed_dict={
                        capsnet.x_input: [clip[0] for clip in batch],
                        capsnet.y_input: [clip[2] for clip in batch],
                        capsnet.m: 0.9,
                        capsnet.is_train: False
                    })
                segmentations.append(segmentation)
                predictions.append(pred)

            # video-level class: argmax of the class scores averaged over clips
            predictions = np.concatenate(predictions, axis=0)
            predictions = predictions.reshape((-1, config.n_classes))
            if np.argmax(np.mean(predictions, axis=0)) == label:
                n_correct += 1

            # both reshaped to (N_FRAMES, 112, 112, 1)
            gt_segmentations = np.stack([clip[1] for clip in clips])
            gt_segmentations = gt_segmentations.reshape((-1, 112, 112, 1))
            pred_segmentations = np.concatenate(segmentations, axis=0)
            pred_segmentations = pred_segmentations.reshape((-1, 112, 112, 1))
            pred_segmentations = (pred_segmentations >= 0.5).astype(np.int32)

            frame_scores, video_score = _segmentation_ious(
                pred_segmentations, gt_segmentations)
            n_tot_frames[label] += frame_scores.size
            frame_ious[label] += np.count_nonzero(
                frame_scores[:, None] >= iou_threshs, axis=0)
            n_vids[label] += 1
            video_ious[label] += video_score >= iou_threshs

            if np.sum(n_vids) % 100 == 0:
                print('Finished %d videos' % np.sum(n_vids))

    n_videos = np.sum(n_vids)
    if n_videos == 0:
        # nothing was scored; avoid reporting NaN metrics from 0 / 0
        print('No test videos with bounding boxes were evaluated.')
        return

    accuracy = n_correct / n_videos
    print('Accuracy:', accuracy)
    config.write_output('Test Accuracy: %.4f\n' % float(accuracy))

    fAP, fmAP = _average_precision(frame_ious, n_tot_frames)
    vAP, vmAP = _average_precision(video_ious, n_vids)

    for name, per_class, mean_ap in (('f-mAP', fAP, fmAP),
                                     ('v-mAP', vAP, vmAP)):
        print('IoU %s:' % name)
        config.write_output('IoU %s:\n' % name)
        for thresh, value in zip(iou_threshs, mean_ap):
            print(thresh, value)
            config.write_output('%.4f\t%.4f\n' % (thresh, value))
        config.write_output(str(per_class[:, mid]) + '\n')
        print(per_class[:, mid])
Example #3
0
def inference(video_name, cropped_path='cropped.avi',
              segmented_path='segmented_vid.avi'):
    """Segment a single video with the saved network and write the results.

    The video is read with ``vread``. Unless it is already 112x112, it is
    resized to 120x160 (assuming a 240x320 source aspect ratio) and then
    center-cropped to 112x112. The crop is written to ``cropped_path``.

    Frames are fed to the network in clips of 8, with ``config.frame_skip``
    interleaved clips per window. The sigmoid segmentation is thresholded at
    0.5, and the masked pixels are blended (alpha 0.5) with the color
    (0, 0, 1) before being written to ``segmented_path``.

    Args:
        video_name: path of the input video.
        cropped_path: output path for the cropped input video.
        segmented_path: output path for the highlighted video.
    """
    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True

    clip_len = 8
    crop_h, crop_w = 112, 112

    capsnet = Caps3d()
    with tf.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.global_variables_initializer().run()
        capsnet.load(sess, config.save_file_name)

        video = vread(video_name)
        n_frames = video.shape[0]

        # assumes a given aspect ratio of (240, 320). Resizing is skipped only
        # when BOTH dimensions are already 112; otherwise a 112x160 video would
        # skip resizing and be cropped from its left edge instead of its center.
        if video.shape[1] == crop_h and video.shape[2] == crop_w:
            h, w = crop_h, crop_w
            video_res = video
        else:
            h, w = 120, 160
            video_res = np.zeros((n_frames, h, w, 3))
            for f in range(n_frames):
                video_res[f] = imresize(video[f], (h, w))

        # crops the central 112x112 region
        h_start = (h - crop_h) // 2
        w_start = (w - crop_w) // 2
        video_cropped = video_res[:, h_start:h_start + crop_h,
                                  w_start:w_start + crop_w, :]

        print('Saving Cropped Video')
        vwrite(cropped_path, video_cropped)

        video_cropped = video_cropped / 255.

        segmentation_output = np.zeros((n_frames, crop_h, crop_w, 1))
        f_skip = config.frame_skip
        window = clip_len * f_skip
        # same shape/dtype as one frame; used to pad past the end of the video
        blank = np.zeros(video_cropped.shape[1:], dtype=video_cropped.dtype)

        for start in range(0, n_frames, window):
            # if frames are skipped (subsampled) during training, they should also be skipped at test time:
            # frame start+k goes to clip k % f_skip at position k // f_skip
            x_batch = [[] for _ in range(f_skip)]
            for k in range(window):
                idx = start + k
                x_batch[k % f_skip].append(
                    video_cropped[idx] if idx < n_frames else blank)
            x_batch = [np.stack(clip, axis=0) for clip in x_batch]

            # runs the network to get segmentations
            seg_out = sess.run(capsnet.segment_layer_sig, feed_dict={
                capsnet.x_input: x_batch,
                capsnet.is_train: False,
                capsnet.y_input: np.ones((f_skip,), np.int32) * -1})

            # collects the segmented frames back into temporal order
            for k in range(min(window, n_frames - start)):
                segmentation_output[start + k] = seg_out[k % f_skip][k // f_skip]

        # Final segmentation output
        segmentation_output = (segmentation_output >= 0.5).astype(np.int32)

        # Highlights the video based on the segmentation
        alpha = 0.5
        color = np.array([0.0, 0.0, 1.0])
        masked_vid = np.where(np.tile(segmentation_output, [1, 1, 3]) == 1,
                              video_cropped * (1 - alpha) + alpha * color,
                              video_cropped)

        print('Saving Segmented Video')
        vwrite(segmented_path, (masked_vid * 255).astype(np.uint8))
def train_network(gpu_config):
    """Train Caps3d on UCF101, optionally resuming from a saved epoch.

    Training starts at ``config.start_at_epoch``. When that is greater than 1,
    the weights saved for the previous epoch (``config.save_file_name % epoch``)
    are loaded first; otherwise the output log is cleared. An epoch whose
    training reports a negative or NaN loss/accuracy is retried. After three
    failures the process exits with status 1.

    Checkpoints are written every ``config.save_every_n_epochs`` epochs. They
    are also written when the validation schedule fires (validation itself is
    disabled in this version) and after the final epoch. A checkpoint is
    written at most once per epoch.

    Args:
        gpu_config: session configuration passed to ``tf.compat.v1.Session``.
    """
    import math
    import sys

    max_nan_tries = 3
    capsnet = Caps3d()

    with tf.compat.v1.Session(graph=capsnet.graph, config=gpu_config) as sess:
        tf.compat.v1.global_variables_initializer().run()

        def save_checkpoint(epoch):
            """Save the weights for ``epoch``; return True on success."""
            try:
                capsnet.save(sess, config.save_file_name % epoch)
            except Exception as err:
                # best effort: a failed checkpoint must not abort training
                print('Failed to save network!!!', err)
                return False
            config.write_output('Saved Network\n')
            return True

        get_num_params()
        if config.start_at_epoch <= 1:
            config.clear_output()
        else:
            capsnet.load(sess, config.save_file_name % (config.start_at_epoch - 1))
            print('Loading from epoch %d.' % (config.start_at_epoch - 1))

        # stays -1 until validation is unlocked, then counts epochs since then
        n_eps_after_acc = -1
        print('Training on UCF101')
        for ep in range(config.start_at_epoch, config.n_epochs + 1):
            print(20 * '*', 'epoch', ep, 20 * '*')

            for _ in range(max_nan_tries):
                # trains network for one epoch
                data_gen = TrainDataGen(config.wait_for_data, frame_skip=config.frame_skip)
                margin_loss, seg_loss, acc = capsnet.train(sess, data_gen)

                # capsnet.train presumably signals a NaN epoch with negative
                # values; NaN is checked explicitly too because NaN < 0 is False
                diverged = (margin_loss < 0 or acc < 0
                            or any(math.isnan(v) for v in (margin_loss, seg_loss, acc)))
                if not diverged:
                    config.write_output('CL: %.4f. SL: %.4f. Acc: %.4f\n' % (margin_loss, seg_loss, acc))
                    break
            else:
                print('Network cannot be trained. Too many NaN issues.')
                # exit() (site builtin) would report success with status 0
                sys.exit(1)

            saved_this_epoch = False
            if ep % config.save_every_n_epochs == 0:
                saved_this_epoch = save_checkpoint(ep)

            # increments the margin
            if ep % config.n_eps_for_m == 0:
                capsnet.cur_m += config.m_delta
                capsnet.cur_m = min(capsnet.cur_m, 0.9)

            # only validates after a certain number of epochs and when the training accuracy is greater than a threshold
            # this is mainly used to save time, since validation takes about 10 minutes
            if (acc >= config.acc_for_eval or n_eps_after_acc >= 0) and ep >= config.n_eps_until_eval:
                n_eps_after_acc += 1

            # n_eps_after_acc >= 0 is required: -1 % 1 == 0 would otherwise fire
            # before n_eps_until_eval when n_eps_for_eval == 1
            eval_due = (n_eps_after_acc >= 0
                        and n_eps_after_acc % config.n_eps_for_eval == 0)
            if (acc >= config.acc_for_eval and eval_due) or ep == config.n_epochs:
                # NOTE: validation on TestDataGen is disabled in this version;
                # the network is only checkpointed here (once per epoch).
                if not saved_this_epoch:
                    save_checkpoint(ep)