Example #1
def create_image_arrays(input_sequence, gradcams, time_mask,
                        output_folder, video_id, mask_type,
                        image_width, image_height):
    '''Write per-frame composites (input | Grad-CAM | perturbed input) to disk.'''
    combined_images = []
    for i in range(FLAGS.seq_length):
        input_data_img = input_sequence[0, i, :, :, :]

        time_mask_copy = time_mask.copy()

        combined_img = np.concatenate((np.uint8(input_data_img),
                                       np.uint8(gradcams[i]),
                                       np.uint8(mask.perturb_sequence(
                                           input_sequence,
                                           time_mask_copy,
                                           perb_type=mask_type,
                                           snap_values=True)[0, i, :, :, :])),
                                      axis=1)

        combined_images.append(combined_img)
        cv2.imwrite(os.path.join(
            output_folder,
            "img%02d.jpg" % (i + 1)),
            combined_img)

    visualize_results_on_gradcam(combined_images,
                                 time_mask,
                                 image_width,
                                 image_height,
                                 root_dir=output_folder,
                                 case=mask_type + video_id)

    return combined_images
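
A hypothetical call for the function above, assuming the surrounding module's FLAGS.seq_length is 16 and its mask helper and OpenCV are importable; the array shapes are inferred from how the arguments are indexed:

import os
import numpy as np

# One batch of a 16-frame RGB clip, one Grad-CAM heatmap per frame, and an
# all-zero temporal mask (no frames perturbed). Shapes are assumptions.
seq = np.random.randint(0, 255, (1, 16, 120, 160, 3)).astype(np.float32)
gradcams = [np.random.randint(0, 255, (120, 160, 3), dtype=np.uint8)
            for _ in range(16)]
time_mask = np.zeros(16, dtype=np.float32)

os.makedirs("results/clip01", exist_ok=True)
combined = create_image_arrays(seq, gradcams, time_mask,
                               output_folder="results/clip01",
                               video_id="clip01", mask_type="freeze",
                               image_width=160, image_height=120)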
Example #2
def find_masks(dat_loader,
               model,
               config,
               lam1,
               lam2,
               N,
               ita,
               maskType="gradient",
               temporalMaskType="freeze",
               classOI=None,
               verbose=True,
               maxMaskLength=None,
               doGradCam=False,
               runTempMask=True):
    '''
    Finds masks for sequences according to the given maskType.
    Input:
        dat_loader: iterable providing input batches (should be val/test)
        model: model to evaluate masks on
        config: dictionary with lr, weight decay and batch size that the grad desc method uses to find the mask
        lam1: weighting factor for the L1 loss
        lam2: weighting factor for the TV norm loss
        N: number of iterations to run when using the grad desc method
        ita: number of times to find the mask when using the grad desc method (useful to evaluate several random inits)
        maskType: how to find the mask. Can be one of:
            'combi': iterate through the different combinations of a coherent 'one blob' mask. Does not use grad desc
            'central': initialize an as-small-as-possible centered mask and use grad desc to find the optimal mask
            'random': initialize a completely random mask and use grad desc to find the optimal mask
        classOI: if only a specific class should be evaluated (must be one of the 174 class numbers)
        verbose: print mask information during grad desc
        temporalMaskType: defines which perturbation type is used to find the first mask indices
    '''
    model.eval()
    masks = []
    resultsPath = "results/"
    clipsTimeMaskResults = []
    clipsGradCamResults = []
    if not os.path.exists(resultsPath):
        os.makedirs(resultsPath)

    if config["splitType"] == "original":
        clips_of_interest = [
            ["person17", "boxing", "d1", "_1"],
            ["person17", "boxing", "d2", "_1"],
            ["person18", "boxing", "d3", "_1"],
            ["person18", "boxing", "d4", "_1"],
            ["person17", "handclapping", "d1", "_1"],
            ["person17", "handclapping", "d2", "_1"],
            ["person18", "handclapping", "d3", "_1"],
            ["person18", "handclapping", "d4", "_1"],
            ["person17", "handwaving", "d1", "_1"],
            ["person17", "handwaving", "d2", "_1"],
            ["person18", "handwaving", "d3", "_1"],
            ["person18", "handwaving", "d4", "_1"],
            ["person24", "jogging", "d1", "_1"],
            ["person24", "jogging", "d2", "_1"],
            ["person25", "jogging", "d3", "_1"],
            ["person25", "jogging", "d4", "_1"],
            ["person24", "running", "d1", "_1"],
            ["person24", "running", "d2", "_1"],
            ["person25", "running", "d3", "_1"],
            ["person25", "running", "d4", "_1"],
            ["person24", "walking", "d1", "_1"],
            ["person24", "walking", "d2", "_1"],
            ["person25", "walking", "d3", "_1"],
            ["person25", "walking", "d4", "_1"],
        ]
    else:
        clips_of_interest = [
            ["person07", "boxing", "d1", "_1"],
            ["person07", "boxing", "d2", "_1"],
            ["person08", "boxing", "d3", "_1"],
            ["person08", "boxing", "d4", "_1"],
            ["person07", "handclapping", "d1", "_1"],
            ["person07", "handclapping", "d2", "_1"],
            ["person08", "handclapping", "d3", "_1"],
            ["person08", "handclapping", "d4", "_1"],
            ["person07", "handwaving", "d1", "_1"],
            ["person07", "handwaving", "d2", "_1"],
            ["person08", "handwaving", "d3", "_1"],
            ["person08", "handwaving", "d4", "_1"],
            ["person09", "jogging", "d1", "_1"],
            ["person09", "jogging", "d2", "_1"],
            ["person10", "jogging", "d3", "_1"],
            ["person10", "jogging", "d4", "_1"],
            ["person09", "running", "d1", "_1"],
            ["person09", "running", "d2", "_1"],
            ["person10", "running", "d3", "_1"],
            ["person10", "running", "d4", "_1"],
            ["person09", "walking", "d1", "_1"],
            ["person09", "walking", "d2", "_1"],
            ["person10", "walking", "d3", "_1"],
            ["person10", "walking", "d4", "_1"],
        ]

    for i, (input_seq, target, label) in enumerate(dat_loader):
        if i % 50 == 0:
            print("on idx: ", i)

        input_var = input_seq.to(device)
        target = target.to(device)

        model.zero_grad()

        # eta is for breaking out of the grad desc early if it hasn't improved
        eta = 0.00001

        haveOutput = False
        for intraBidx in range(config["batch_size"]):

            targTag = label[intraBidx]
            tagFound = False

            for coi in clips_of_interest:
                if all([coit in targTag for coit in coi]):
                    tagFound = True

            if tagFound:

                if not haveOutput:
                    output = model(input_var)
                    haveOutput = True

                if runTempMask:
                    if config["gradCamType"] == "guessed":
                        maskTarget = torch.zeros(
                            (config["batch_size"], 1)).long()
                        maskTarget[intraBidx] = torch.argmax(output[intraBidx])

                    else:
                        maskTarget = target

                    # gradient descent for finding temporal masks
                    model.zero_grad()
                    time_mask = mask.init_mask(input_var,
                                               model,
                                               intraBidx,
                                               maskTarget,
                                               thresh=0.9,
                                               mode="central",
                                               maskPertType=temporalMaskType)
                    optimizer = torch.optim.Adam([time_mask], lr=0.2)
                    old_loss = 999999
                    for nidx in range(N):

                        if nidx % 25 == 0:
                            print("on nidx: ", nidx)

                        mask_clip = torch.sigmoid(time_mask)
                        l1loss = lam1 * torch.sum(torch.abs(mask_clip))
                        tvnorm_loss = lam2 * mask.calc_TVNorm(
                            mask_clip, p=3, q=3)

                        class_loss = model(
                            mask.perturb_sequence(input_var,
                                                  mask_clip,
                                                  perbType=temporalMaskType))

                        class_loss = class_loss[intraBidx,
                                                maskTarget[intraBidx]]

                        loss = l1loss + tvnorm_loss + class_loss

                        if abs(old_loss - loss.item()) < eta:
                            break
                        old_loss = loss.item()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    time_mask = torch.sigmoid(time_mask)
                    score_save_path = os.path.join(
                        "cam_saved_images", args.subDir,
                        str(target[intraBidx].item()), label[intraBidx] +
                        "g_" + str(torch.argmax(output[intraBidx]).item()) +
                        "_gs%5.4f" % torch.max(output[intraBidx]).item() +
                        "_cs%5.4f" %
                        output[intraBidx][target[intraBidx]].item(),
                        "combined")

                    if not os.path.exists(score_save_path):
                        os.makedirs(score_save_path)

                    f = open(
                        score_save_path + "/ClassScoreFreezecase" +
                        label[intraBidx] + ".txt", "w+")
                    f.write(str(class_loss.item()))
                    f.close()

                    class_loss_reverse = model(
                        mask.perturb_sequence(input_var,
                                              time_mask,
                                              perbType="reverse"))
                    class_loss_reverse = class_loss_reverse[
                        intraBidx, maskTarget[intraBidx]]

                    f = open(
                        score_save_path + "/ClassScoreReversecase" +
                        label[intraBidx] + ".txt", "w+")
                    f.write(str(class_loss_reverse.item()))
                    f.close()

                    # As soon as you have the time mask and the freeze/reverse scores,
                    # add the results for the current clip to the list of time mask results
                    clipsTimeMaskResults.append({
                        'true_class': int(target[intraBidx].item()),
                        'pred_class': int(torch.argmax(output[intraBidx]).item()),
                        'video_id': label[intraBidx],
                        'time_mask': time_mask.detach().cpu().numpy(),
                        'original_score_guess': torch.max(output[intraBidx]).item(),
                        'original_score_true': output[intraBidx][target[intraBidx]].item(),
                        'freeze_score': class_loss.item(),
                        'reverse_score': class_loss_reverse.item()
                    })

                    if verbose:
                        print("resulting mask is: ", time_mask)

                if doGradCam:

                    target_index = maskTarget[intraBidx]

                    RESIZE_SIZE_WIDTH = 160
                    RESIZE_SIZE_HEIGHT = 120

                    grad_cam = GradCamVideo(
                        model=model.module,
                        target_layer_names=[
                            'Mixed_5c'
                        ],  # model.module.end_points and ["block5"],
                        class_dict=None,
                        use_cuda=True,
                        input_spatial_size=(RESIZE_SIZE_WIDTH,
                                            RESIZE_SIZE_HEIGHT),
                        normalizePerFrame=True,
                        archType="I3D")
                    input_to_model = input_var[intraBidx][None, :, :, :, :]

                    if config["gradCamType"] == "guessed":
                        target_index = torch.argmax(output[intraBidx])

                    # Renamed from 'mask' so the heatmap does not shadow the mask module used below
                    gradcam_mask, output_grad = grad_cam(input_to_model, target_index)

                    # As soon as you have the numpy array of the Grad-CAM heatmap, add it to the list of heatmaps
                    clipsGradCamResults.append({
                        'true_class': int(target[intraBidx].item()),
                        'pred_class': int(torch.argmax(output[intraBidx]).item()),
                        'video_id': label[intraBidx],
                        'GCHeatMap': gradcam_mask
                    })
                    # Beginning of Grad-CAM write to disk
                    input_data_unnormalised = input_to_model[0].cpu().permute(
                        1, 2, 3, 0).numpy()
                    input_data_unnormalised = np.flip(input_data_unnormalised,
                                                      3)

                    targTag = label[intraBidx]

                    output_images_folder_cam_combined = os.path.join("cam_saved_images", args.subDir,
                                                                     str(target[intraBidx].item()), \
                                                                     targTag + "g_" + str(
                                                                         torch.argmax(output[intraBidx]).item()) \
                                                                     + "_gs%5.4f" % torch.max(output[intraBidx]).item() \
                                                                     + "_cs%5.4f" % output[intraBidx][
                                                                         target[intraBidx]].item(), "combined")

                    os.makedirs(output_images_folder_cam_combined,
                                exist_ok=True)

                    RESIZE_FLAG = 0

                if doGradCam and runTempMask:
                    viz.create_image_arrays(
                        input_var, gradcam_mask, time_mask, intraBidx, "freeze",
                        output_images_folder_cam_combined, targTag,
                        RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)
                    viz.create_image_arrays(
                        input_var, gradcam_mask, time_mask, intraBidx, "reverse",
                        output_images_folder_cam_combined, targTag,
                        RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)

                if runTempMask:
                    viz.vizualize_results(
                        input_var[intraBidx],
                        mask.perturb_sequence(
                            input_var, time_mask,
                            perbType=temporalMaskType)[intraBidx],
                        time_mask,
                        rootDir=output_images_folder_cam_combined,
                        case=targTag,
                        markImgs=True,
                        iterTest=False)

                    masks.append(time_mask)

    # Finally, write the pickle files to disk
    f = open(
        resultsPath + "I3d_KTH_allTimeMaskResults_original_" + args.subDir +
        ".p", "wb")
    pickle.dump(clipsTimeMaskResults, f)
    f.close()

    f = open(
        resultsPath + "I3d_KTH_allGradCamResults_original_" + args.subDir +
        ".p", "wb")
    pickle.dump(clipsGradCamResults, f)
    f.close()

    return masks
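
mask.perturb_sequence itself is not included in these examples. Below is a minimal sketch of plausible 'freeze' and 'reverse' semantics, inferred from the tf.scan recurrence in the TensorFlow examples further down; the 'reverse' branch in particular is an assumption, not the original implementation.

import torch

def perturb_sequence_sketch(seq, time_mask, perb_type="freeze"):
    # seq: (B, T, C, H, W) clip, time_mask: (T,) values in [0, 1].
    out = seq.clone()
    T = seq.shape[1]
    if perb_type == "freeze":
        # Blend each frame toward the running "frozen" frame as the mask rises:
        # out_t = (1 - m_t) * x_t + m_t * out_{t-1}
        for t in range(1, T):
            out[:, t] = (1 - time_mask[t]) * seq[:, t] + time_mask[t] * out[:, t - 1]
    elif perb_type == "reverse":
        # Blend masked frames toward their time-mirrored counterparts (assumption).
        for t in range(T):
            out[:, t] = (1 - time_mask[t]) * seq[:, t] + time_mask[t] * seq[:, T - 1 - t]
    return out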
Example #3
def find_masks(dat_loader, model, hyper_params, lam1, lam2, N, maskType="gradient", temporalMaskType="freeze",
               classOI=None, verbose=True, maxMaskLength=None, doGradCam=False, runTempMask=True):
    '''
    Finds masks for sequences according to the given maskType.
    Input:
        dat_loader: iterable providing input batches (should be val/test)
        model: model to evaluate masks on
        hyper_params: dictionary with lr, weight decay and batch size that the grad desc method uses to find the mask
        lam1: weighting factor for the L1 loss
        lam2: weighting factor for the TV norm loss
        N: number of iterations to run when using the grad desc method
        maskType: how to find the mask. Can be one of:
            'combi': iterate through the different combinations of a coherent 'one blob' mask. Does not use grad desc
            'central': initialize an as-small-as-possible centered mask and use grad desc to find the optimal mask
            'random': initialize a completely random mask and use grad desc to find the optimal mask
        classOI: if only specific classes should be evaluated (must be one of the 174 class numbers)
        verbose: print mask information during grad desc
        temporalMaskType: defines which perturbation type is used to find the first mask indices
    '''
    model.eval()
    masks = []
    df = pd.read_csv(classOI) if classOI is not None else None
    results_path = 'results/'
    clips_time_mask_results = []
    clips_grad_cam_results = []
    if not os.path.exists(results_path):
        os.makedirs(results_path)

    for i, (sequence, label, video_id) in enumerate(dat_loader):
        if i % 50 == 0:
            print("on idx: ", i)

        input_var = sequence.to(device)
        label = label.to(device)

        model.zero_grad()

        # eta is for breaking out of the grad desc early if it hasn't improved
        eta = 0.00001

        for batch_index in range(hyper_params['batch_size']):

            # Only look at clips of a certain class (if a class-of-interest file 'classOI' was given)
            true_class = label[batch_index].item()
            true_class_str = str(true_class)

            if classOI is None or (
                    true_class_str in df.columns
                    and int(video_id[batch_index]) in df[true_class_str].tolist()):

                output = model(input_var)

                if runTempMask:
                    if hyper_params["gradCamType"] == "guessed":
                        mask_target = torch.zeros((hyper_params["batch_size"], 1)).long()
                        mask_target[batch_index] = torch.argmax(output[batch_index])

                    else:
                        mask_target = label

                    # Gradient descent for finding the temporal mask
                    # (combinatorial mask finding is straightforward and needs no grad desc)
                    model.zero_grad()
                    time_mask = mask.init_mask(
                        input_var, model, batch_index, mask_target, threshold=0.9,
                        mode="central", mask_type=temporalMaskType)
                    optimizer = torch.optim.Adam([time_mask], lr=0.2)
                    oldLoss = 999999
                    for nidx in range(N):

                        if nidx % 25 == 0:
                            print("on nidx: ", nidx)

                        mask_clip = torch.sigmoid(time_mask)
                        l1loss = lam1 * torch.sum(torch.abs(mask_clip))
                        tvnorm_loss = lam2 * mask.calc_tv_norm(mask_clip, p=3, q=3)

                        class_loss = model(mask.perturb_sequence(
                            input_var, mask_clip, perturbation_type=temporalMaskType))

                        class_loss = class_loss[batch_index, mask_target[batch_index]]

                        loss = l1loss + tvnorm_loss + class_loss

                        if abs(oldLoss - loss.item()) < eta:
                            break
                        oldLoss = loss.item()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    time_mask = torch.sigmoid(time_mask)
                    pred_class = int(torch.argmax(output[batch_index]).item())
                    original_score_guess = torch.max(output[batch_index]).item()
                    original_score_true = output[batch_index][label[batch_index]].item()
                    # Use a new name instead of rebinding video_id, which still holds the whole batch
                    clip_id = video_id[batch_index]

                    score_save_path = os.path.join(
                        "cam_saved_images", args.subDir, str(true_class),
                        clip_id + "g_" + str(pred_class) + "_gs%5.4f" % original_score_guess +
                        "_cs%5.4f" % original_score_true, "combined")

                    if not os.path.exists(score_save_path):
                        os.makedirs(score_save_path)

                    f = open(score_save_path + "/ClassScoreFreezecase" + video_id[batch_index] + ".txt", "w+")
                    f.write(str(class_loss.item()))
                    f.close()

                    class_loss_reverse = model(mask.perturb_sequence(input_var, time_mask, perturbation_type="reverse"))
                    class_loss_reverse = class_loss_reverse[batch_index, mask_target[batch_index]]

                    f = open(score_save_path + "/ClassScoreReversecase" + video_id[batch_index] + ".txt", "w+")
                    f.write(str(class_loss_reverse.item()))
                    f.close()

                    # As soon as you have the time mask and the freeze/reverse scores,
                    # add the results for the current clip to the list of time mask results
                    clips_time_mask_results.append({'true_class': true_class,
                                                    'pred_class': pred_class,
                                                    'video_id': clip_id,
                                                    'time_mask': time_mask.detach().cpu().numpy(),
                                                    'original_score_guess': original_score_guess,
                                                    'original_score_true': original_score_true,
                                                    'freeze_score': class_loss.item(),
                                                    'reverse_score': class_loss_reverse.item()
                                                    })

                if doGradCam:

                    target_index = mask_target[batch_index]

                    # Spatial size for Grad-CAM; values assumed to match the I3D example above
                    # (they are not defined anywhere in this snippet)
                    RESIZE_SIZE_WIDTH = 160
                    RESIZE_SIZE_HEIGHT = 120

                    grad_cam = GradCamVideo(model=model.module,
                                            target_layer_names=['Mixed_5c'],  # model.module.end_points and ["block5"],
                                            class_dict=None,
                                            use_cuda=True,
                                            input_spatial_size=(RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT),
                                            normalizePerFrame=True,
                                            archType="I3D")
                    input_to_model = input_var[batch_index][None, :, :, :, :]

                    if hyper_params["gradCamType"] == "guessed":
                        target_index = torch.argmax(output[batch_index])

                    # Renamed from 'mask' so the heatmap does not shadow the mask module
                    gradcam_mask, output_grad = grad_cam(input_to_model, target_index)

                    # As soon as you have the numpy array of the Grad-CAM heatmap, add it to the list of heatmaps
                    clips_grad_cam_results.append(
                        {'true_class': int(label[batch_index].item()),
                         'pred_class': int(torch.argmax(output[batch_index]).item()),
                         'video_id': int(video_id[batch_index]),
                         'GCHeatMap': gradcam_mask
                         })

                    # Beginning of Grad-CAM write to disk
                    input_data_unnormalised = input_to_model[0].cpu().permute(1, 2, 3, 0).numpy()
                    input_data_unnormalised = np.flip(input_data_unnormalised, 3)

                    targTag = video_id[batch_index]

                    output_images_folder_cam_combined = os.path.join(
                        "cam_saved_images", args.subDir, str(label[batch_index].item()),
                        targTag + "g_" + str(torch.argmax(output[batch_index]).item()) +
                        "_gs%5.4f" % torch.max(output[batch_index]).item() +
                        "_cs%5.4f" % output[batch_index][label[batch_index]].item(),
                        "combined")

                    os.makedirs(output_images_folder_cam_combined, exist_ok=True)

                    RESIZE_FLAG = 0

                if doGradCam and runTempMask:
                    viz.create_image_arrays(
                        input_var, gradcam_mask, time_mask, batch_index, "freeze",
                        output_images_folder_cam_combined, targTag,
                        RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)
                    viz.create_image_arrays(
                        input_var, gradcam_mask, time_mask, batch_index, "reverse",
                        output_images_folder_cam_combined, targTag,
                        RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)

                    masks.append(time_mask)

    # Finally, write the pickle files to disk (str() guards against classOI being None)
    f = open(results_path + "allTimeMaskResults_" + args.subDir + "_" + str(classOI) + ".p", "wb")
    pickle.dump(clips_time_mask_results, f)
    f.close()

    f = open(results_path + "allGradCamResults_" + args.subDir + "_" + classOI + "_" + ".p", "wb")
    pickle.dump(clips_grad_cam_results, f)
    f.close()

    return masks
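
mask.calc_tv_norm / calc_TVNorm is also not shown in these examples. A minimal sketch of a total-variation penalty on a 1-D temporal mask, with the exact use of the p and q exponents treated as an assumption:

import torch

def calc_tv_norm_sketch(mask_clip, p=3, q=3):
    # Penalize frame-to-frame jumps |m_{t+1} - m_t| so the optimized mask
    # stays one coherent block instead of flickering. How p and q enter the
    # original implementation is an assumption here.
    diffs = torch.abs(mask_clip[1:] - mask_clip[:-1])
    return torch.sum(diffs ** p) ** q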
Example #4
def main(argv):

    df = pd.read_csv(FLAGS.clip_selection)  # DataFrame containing the clips to run on.

    # First we need to recreate the same variables as in the model.
    tf.reset_default_graph()
    seq_shape = (FLAGS.batch_size, FLAGS.seq_length, FLAGS.image_height, FLAGS.image_width, 3)
    seq_zeros = np.zeros(seq_shape)

    # Build graph
    graph = tf.Graph()

    # Graph for perturb_sequence(seq, mask, perb_type) method
    # Create variable to save original input sequence
    with tf.variable_scope('original_input'):
        original_input_plhdr = tf.placeholder(tf.float32, seq_shape)
        original_input_var = tf.get_variable('original_input',
                                             seq_shape,
                                             dtype=tf.float32,
                                             trainable=False)
        original_input_assign = original_input_var.assign(original_input_plhdr)

    x = tf.placeholder(tf.float32, seq_shape)

    with tf.variable_scope('mask'):
        # Create variable for the temporal mask
        mask_plhdr = tf.placeholder(tf.float32, [FLAGS.seq_length])
        mask_var = tf.get_variable('input_mask',
                                   [FLAGS.seq_length],
                                   dtype=tf.float32,
                                   trainable=True)
        mask_assign = tf.assign(mask_var, mask_plhdr)
        mask_clip = tf.nn.sigmoid(mask_var)

    with tf.variable_scope('perturb'):

        frame_inds = tf.placeholder(tf.int32, shape=(None,), name='frame_inds')

        def recurrence(last_value, current_elem):
            update_tensor = (1 - mask_clip[current_elem]) * original_input_var[:, current_elem, :, :, :] + \
                            mask_clip[current_elem] * last_value
            return update_tensor

        perturb_op = tf.scan(fn=recurrence,
                             elems=frame_inds,
                             initializer=original_input_var[:, 0, :, :, :])
        perturb_op = tf.reshape(perturb_op, seq_shape)

    y = tf.placeholder(tf.float32, [FLAGS.batch_size, NUM_CLASSES])
    logits, clstm_3 = clstm.clstm(perturb_op, bn=False, is_training=False, num_classes=NUM_CLASSES)
    after_softmax = tf.nn.softmax(logits)


    # Settings for the temporal mask method
    N = FLAGS.nb_iterations_graddescent
    maskType = 'gradient'
    verbose = True
    maxMaskLength = None
    do_gradcam = True
    run_temp_mask = True
    ita = 1

    variables_to_restore = {}
    for variable in tf.global_variables():
        if variable.name.startswith('mask'):
            continue
        elif variable.name.startswith('original_input'):
            continue
        else:
            # Variables need to be renamed to match with the checkpoint.
            variables_to_restore[variable.name.replace(':0','')] = variable

    with tf.Session() as sess:
        saver = tf.train.Saver(var_list=variables_to_restore)  # All but the input which is a variable
        saver.restore(sess, "/workspace/checkpoints/" + FLAGS.checkpoint_name)

        l1loss = FLAGS.lambda_1 * tf.reduce_sum(tf.abs(mask_clip))
        tvnormLoss = FLAGS.lambda_2 * mask.calc_TV_norm(mask_clip, p=3, q=3)
        if FLAGS.focus_type == 'correct':
            label_index = tf.reshape(tf.argmax(y, axis=1), [])
        if FLAGS.focus_type == 'guessed':
            label_index = tf.reshape(tf.argmax(logits, axis=1), [])
        class_loss = after_softmax[:, label_index]
        # Cast as same type as l1 and TV.
        class_loss = tf.cast(class_loss, tf.float32)

        loss_function = l1loss + tvnormLoss + class_loss

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate_start)
        train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'mask')

        with tf.variable_scope('minimize'):
            training_op = optimizer.minimize(loss_function, var_list=train_var)

        sess.run(tf.variables_initializer(optimizer.variables()))

        for ind, row in df.iterrows():

            # Get values
            subject = row['Subject']
            video_ID = row['Video_ID']
            label = row['Label']
            # Retrieve the frame paths for this clip
            paths = get_paths_for_sequence(subject, video_ID)
            # Prepare one "batch" of data (bs=1)
            input_var, label = data_for_one_sequence_5D(paths, label)

            # Only look at clips of a certain class (if a class-of-interest selection was given)
            current_class = np.argmax(label)
            print(current_class)

            if (current_class in df['Label'].values.tolist() and
                    video_ID in df[df['Label'] == current_class]['Video_ID'].values.tolist()):

                print("found clip of interest ", current_class, video_ID)

                preds = sess.run(logits, feed_dict={mask_var: np.zeros(FLAGS.seq_length),
                                                    original_input_var: input_var,
                                                    frame_inds: range(FLAGS.seq_length)})
                print('np argmax preds', np.argmax(preds))

                #  eta is for breaking out of the grad desc early if it hasn't improved
                eta = 0.00001

                have_output = False

                if preds[:, int(current_class)] < 0.1:
                    print('the guess for the correct class was less than 0.1')
                    continue

                #  Start mask optimization, should be
                #  Loss = lam1*||Mask size|| + lam2*Beta loss + class_score

                if not have_output:
                    output = sess.run(after_softmax, feed_dict={mask_var: np.zeros(FLAGS.seq_length),
                                                                original_input_var: input_var,
                                                                frame_inds: range(FLAGS.seq_length)})
                    have_output = True

                if run_temp_mask:
                    if maskType == 'gradient':
                        start_mask = mask.init_mask(input_var, mask_var, original_input_var,
                                                    frame_inds, after_softmax, sess,
                                                    label, thresh=0.9,
                                                    mode="central", mask_pert_type=FLAGS.temporal_mask_type)
                        # Initialize mask variable
                        sess.run(mask_assign, {mask_plhdr: start_mask})
                        sess.run(original_input_assign,
                                 {original_input_plhdr: input_var})

                        oldLoss = 999999
                        for nidx in range(N):

                            if nidx % 10 == 0:
                                print("on nidx: ", nidx)
                                print("mask_clipped is: ", sess.run(mask_clip))

                            _, loss_value, l1value, tvvalue, classlossvalue = sess.run(
                                [training_op, loss_function, l1loss, tvnormLoss, class_loss],
                                feed_dict={y: label,
                                           frame_inds: range(FLAGS.seq_length),
                                           original_input_var: input_var})

                            print("LOSS: {}, l1loss: {}, tv: {}, class: {}".format(loss_value,
                                                                                   l1value,
                                                                                   tvvalue,
                                                                                   classlossvalue))
                            if abs(oldLoss - loss_value) < eta:
                                break
                            oldLoss = loss_value

                        time_mask = sess.run(mask_clip)
                        save_path = os.path.join("cam_saved_images",
                                                 FLAGS.output_folder,
                                                 str(np.argmax(label)),
                                                 video_ID + "g_" +
                                                 str(np.argmax(output)) +
                                                 "_cs%5.4f" % output[:, np.argmax(label)] +
                                                 "gs%5.4f" % output[:, np.argmax(output)],
                                                 "combined")

                        if not os.path.exists(save_path):
                            os.makedirs(save_path)

                        f = open(save_path+"/ClassScoreFreezecase"+video_ID+".txt","w+")
                        f.write(str(classlossvalue))
                        f.close()

                        if FLAGS.temporal_mask_type == 'reverse':

                            perturbed_sequence = mask.perturb_sequence(input_var, time_mask, perb_type='reverse')

                            class_loss_rev = sess.run(class_loss, feed_dict={mask_var: np.zeros(FLAGS.seq_length),
                                                                             original_input_var: perturbed_sequence,
                                                                             frame_inds: range(FLAGS.seq_length)})
                            f = open(save_path+"/ClassScoreReversecase" + video_ID + ".txt","w+")
                            f.write(str(class_loss_rev))
                            f.close()

                    if verbose:
                        print("resulting mask is: ", sess.run(mask_clip))

                if do_gradcam:

                    if FLAGS.focus_type == "guessed":
                        target_index = np.argmax(output)
                    if FLAGS.focus_type == "correct":
                        target_index = np.argmax(label)

                    gradcam = gc.get_gradcam(sess, logits, clstm_3, y, original_input_var, mask_var, frame_inds,
                                             input_var, label, target_index, FLAGS.image_height, FLAGS.image_width)

                    # Beginning of Grad-CAM write to disk
                    os.makedirs(save_path, exist_ok=True)

                if do_gradcam and run_temp_mask:
                    viz.create_image_arrays(input_var, gradcam, time_mask,
                                            save_path, video_ID, 'freeze',
                                            FLAGS.image_width, FLAGS.image_height)

                    if FLAGS.temporal_mask_type == 'reverse':
                        # Also create the image arrays for the reverse operation.
                        viz.create_image_arrays(input_var, gradcam, time_mask,
                                                save_path, video_ID, 'reverse',
                                                FLAGS.image_width, FLAGS.image_height)

                if run_temp_mask:
                    viz.visualize_results(input_var,
                                          mask.perturb_sequence(input_var,
                                                                time_mask, perb_type='reverse'),
                                          time_mask,
                                          root_dir=save_path,
                                          case=video_ID, mark_imgs=True,
                                          iter_test=False)
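
The tf.scan recurrence in this example keeps the 'freeze' perturbation inside the graph so the mask stays differentiable. The same computation written out in plain NumPy, with shapes matching seq_shape:

import numpy as np

def freeze_perturb_numpy(seq, mask_clip):
    # seq: (batch, T, H, W, 3) clip, mask_clip: (T,) sigmoid-squashed mask.
    out = np.empty_like(seq)
    last = seq[:, 0]  # the tf.scan initializer is the first frame
    for t in range(seq.shape[1]):
        last = (1 - mask_clip[t]) * seq[:, t] + mask_clip[t] * last
        out[:, t] = last
    return out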
Example #5
def main(argv):
    df = pd.read_csv(FLAGS.clip_selection)  # DataFrame containing the clips to run on.

    # First we need to recreate the same variables as in the model.
    tf.reset_default_graph()
    seq_shape = (FLAGS.batch_size, FLAGS.seq_length, FLAGS.image_size, FLAGS.image_size, 3)
    seq_zeros = np.zeros(seq_shape)

    # Build graph
    graph = tf.Graph()

    # Graph for perturb_sequence(seq, mask, perb_type) method
    # Create variable to save original input sequence
    with tf.variable_scope('original_input'):
        original_input_plhdr = tf.placeholder(tf.float32, seq_shape)
        original_input_var = tf.get_variable('original_input',
                                             seq_shape,
                                             dtype=tf.float32,
                                             trainable=False)
        original_input_assign = original_input_var.assign(original_input_plhdr)

    x = tf.placeholder(tf.float32, seq_shape)

    with tf.variable_scope('mask'):
        # Create variable for the temporal mask
        mask_plhdr = tf.placeholder(tf.float32, [FLAGS.seq_length])
        mask_var = tf.get_variable('input_mask',
                                   [FLAGS.seq_length],
                                   dtype=tf.float32,
                                   trainable=True)
        mask_assign = tf.assign(mask_var, mask_plhdr)
        mask_clip = tf.nn.sigmoid(mask_var)

    with tf.variable_scope('perturb'):

        frame_inds = tf.placeholder(tf.int32, shape=(None,), name='frame_inds')

        def recurrence(last_value, current_elem):
            update_tensor = (1 - mask_clip[current_elem]) * original_input_var[:, current_elem, :, :, :] + \
                            mask_clip[current_elem] * last_value
            return update_tensor

        perturb_op = tf.scan(fn=recurrence,
                             elems=frame_inds,
                             initializer=original_input_var[:, 0, :, :, :])
        perturb_op = tf.reshape(perturb_op, seq_shape)

    y = tf.placeholder(tf.float32, [FLAGS.batch_size, NUM_CLASSES])
    logits, clstm_3 = clstm(perturb_op)
    after_softmax = tf.nn.softmax(logits)

    validation_dataset = create_dataset(FLAGS.val_data)

    # Re-initializable iterator
    iterator = tf.data.Iterator.from_structure(
        validation_dataset.output_types, validation_dataset.output_shapes)
    next_element = iterator.get_next()

    validation_init_op = iterator.make_initializer(validation_dataset,
                                                   name='val_init_op')

    if not os.path.exists(FLAGS.output_folder):
        os.makedirs(FLAGS.output_folder)

    STEPS_VAL = int(FLAGS.nb_val_samples / FLAGS.batch_size)
    # Settings for temporal mask method
    N = FLAGS.nb_iterations_graddescent
    mask_type = 'gradient'
    verbose = True
    maxMaskLength = None
    do_gradcam = True
    run_temp_mask = True
    ita = 1

    variables_to_restore = {}
    for variable in tf.global_variables():
        if variable.name.startswith('mask'):
            continue
        elif variable.name.startswith('original_input'):
            continue
        else:
            # Variables need to be renamed to match with the checkpoint.
            variables_to_restore[variable.name.replace(':0', '')] = variable

    with tf.Session() as sess:
        saver = tf.train.Saver(var_list=variables_to_restore)  # All but the input which is a variable
        saver.restore(sess, "/workspace/checkpoints/3lyr_32_mom_wholeseq_bs8")

        sess.run(validation_init_op)

        l1loss = FLAGS.lambda_1 * tf.reduce_sum(tf.abs(mask_clip))
        tvnormLoss = FLAGS.lambda_2 * mask.calc_TV_norm(mask_clip, p=3, q=3)
        if FLAGS.focus_type == 'correct':
            label_index = tf.reshape(tf.argmax(y, axis=1), [])
        if FLAGS.focus_type == 'guessed':
            label_index = tf.reshape(tf.argmax(logits, axis=1), [])
        class_loss = after_softmax[:, label_index]
        # Cast as same type as l1 and TV.
        class_loss = tf.cast(class_loss, tf.float32)

        loss_function = l1loss + tvnormLoss + class_loss

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate_start)
        train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'mask')

        with tf.variable_scope('minimize'):
            training_op = optimizer.minimize(loss_function, var_list=train_var)

        sess.run(tf.variables_initializer(optimizer.variables()))

        for i in range(STEPS_VAL):

            input_var, label, video_id = sess.run(next_element)

            # video_id is returned as an array of a bytes object
            video_id = video_id[0].decode("utf-8")  # like so: array([b'74225'], dtype=object)

            # Only look at cases where the network was correct and the clip is of
            # a certain class (if a class of interest 'classOI' was given)
            current_class = str(np.argmax(label))
            print(current_class)

            if current_class in list(df.keys()) and int(video_id) in df[current_class].tolist():

                print("found clip of interest ", current_class, video_id)

                preds = sess.run(logits, feed_dict={mask_var: np.zeros((FLAGS.seq_length)),
                                                    original_input_var: input_var,
                                                    frame_inds: range(FLAGS.seq_length)})
                print('np argmax preds', np.argmax(preds))

                # eta is for breaking out of the grad desc early if it hasn't improved
                eta = 0.00001

                have_output = False

                if preds[:, int(current_class)] < 0.1:
                    print('the guess for the correct class was less than 0.1')
                    continue

                if not have_output:
                    output = sess.run(after_softmax, feed_dict={mask_var: np.zeros((FLAGS.seq_length)),
                                                                original_input_var: input_var,
                                                                frame_inds: range(FLAGS.seq_length)})
                    have_output = True

                if run_temp_mask:
                    if mask_type == 'gradient':
                        start_mask = mask.init_mask(input_var, mask_var, original_input_var,
                                                    frame_inds, after_softmax, sess,
                                                    label, thresh=0.9,
                                                    mode="central", mask_pert_type=FLAGS.temporal_mask_type)
                        # Initialize mask variable
                        sess.run(mask_assign, {mask_plhdr: start_mask})
                        sess.run(original_input_assign,
                                 {original_input_plhdr: input_var})

                        oldLoss = 999999
                        for nidx in range(N):

                            if nidx % 10 == 0:
                                print("on nidx: ", nidx)
                                print("mask_clipped is: ", sess.run(mask_clip))

                            _, loss_value, l1value, tvvalue, classlossvalue = sess.run(
                                [training_op, loss_function, l1loss, tvnormLoss, class_loss],
                                feed_dict={y: label,
                                           frame_inds: range(FLAGS.seq_length),
                                           original_input_var: input_var})

                            print("LOSS: {}, l1loss: {}, tv: {}, class: {}".format(loss_value,
                                                                                   l1value,
                                                                                   tvvalue,
                                                                                   classlossvalue))
                            if abs(oldLoss - loss_value) < eta:
                                break
                            oldLoss = loss_value

                        time_mask = sess.run(mask_clip)
                        save_path = os.path.join("cam_saved_images",
                                                 FLAGS.output_folder,
                                                 str(np.argmax(label)),
                                                 video_id + "g_" +
                                                 str(np.argmax(output)) +
                                                 "_cs%5.4f" % output[:, np.argmax(label)] +
                                                 "gs%5.4f" % output[:, np.argmax(output)],
                                                 "combined")

                        if not os.path.exists(save_path):
                            os.makedirs(save_path)

                        f = open(save_path + "/ClassScoreFreezecase" + video_id + ".txt", "w+")
                        f.write(str(classlossvalue))
                        f.close()

                        if FLAGS.temporal_mask_type == 'reverse':
                            perturbed_sequence = mask.perturb_sequence(input_var, time_mask, perb_type='reverse')

                            class_loss_rev = sess.run(class_loss, feed_dict={mask_var: np.zeros((FLAGS.seq_length)),
                                                                             original_input_var: perturbed_sequence,
                                                                             frame_inds: range(FLAGS.seq_length)})
                            f = open(save_path + "/ClassScoreReversecase" + video_id + ".txt", "w+")
                            f.write(str(class_loss_rev))
                            f.close()

                    if verbose:
                        print("resulting mask is: ", sess.run(mask_clip))

                if do_gradcam:

                    if FLAGS.focus_type == "guessed":
                        target_index = np.argmax(output)
                    if FLAGS.focus_type == "correct":
                        target_index = np.argmax(label)

                    gradcam = gc.get_gradcam(sess, logits, clstm_3, y,
                                             original_input_var, mask_var, frame_inds,
                                             input_var, label, target_index,
                                             FLAGS.image_size, FLAGS.image_size)

                    # Beginning of Grad-CAM write to disk

                    os.makedirs(save_path, exist_ok=True)

                if do_gradcam and run_temp_mask:
                    viz.create_image_arrays(
                        input_var, gradcam, time_mask,
                        save_path, video_id, 'freeze',
                        FLAGS.image_size, FLAGS.image_size)

                    if FLAGS.temporal_mask_type == 'reverse':
                        # Also create the image arrays for the reverse operation.
                        viz.create_image_arrays(
                            input_var, gradcam, time_mask,
                            save_path, video_id, 'reverse',
                            FLAGS.image_size, FLAGS.image_size)

                if run_temp_mask:
                    viz.visualize_results(
                        input_var,
                        mask.perturb_sequence(
                            input_var, time_mask, perb_type='reverse'),
                        time_mask,
                        root_dir=save_path,
                        case=video_id, mark_imgs=True,
                        iter_test=False)
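
The 'central' initialization that init_mask is called with throughout these examples (an as-small-as-possible centered mask) is not shown. One hedged sketch of how such a pre-sigmoid initialization could look; the span and logit magnitudes are illustrative assumptions, not the original implementation:

import numpy as np

def init_central_mask_sketch(seq_length, span=2):
    # Pre-sigmoid logits: +3 -> sigmoid ~0.95 ("masked"), -3 -> ~0.05
    # ("unmasked"). A small centered block of frames starts masked.
    logits = np.full(seq_length, -3.0, dtype=np.float32)
    mid = seq_length // 2
    logits[mid - span // 2: mid + (span + 1) // 2] = 3.0
    return logits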