def create_image_arrays(input_sequence, gradcams, time_mask, output_folder,
                        video_id, mask_type, image_width, image_height):
    """Build and save side-by-side frames of input, Grad-CAM and perturbed input.

    For every frame index ``0 .. FLAGS.seq_length - 1`` of the first sequence
    in the batch (batch index 0), three images are cast to uint8 and
    concatenated along ``axis=1``:

    - the original frame,
    - ``gradcams[i]``,
    - the same frame of the sequence perturbed with ``time_mask``.

    Each combined frame is written to ``output_folder`` as ``img01.jpg``,
    ``img02.jpg``, ... and the whole list is then passed to
    ``visualize_results_on_gradcam``.

    Args:
        input_sequence: Batch of sequences indexed as ``[batch, frame, ...]``;
            only batch entry 0 is visualized.
            NOTE(review): presumably (batch, seq, H, W, C) -- confirm.
        gradcams: Per-frame Grad-CAM images, indexable by frame index.
        time_mask: Temporal mask (must support ``.copy()``) used to perturb
            ``input_sequence``.
        output_folder: Directory the ``imgNN.jpg`` files are written to.
        video_id: Clip identifier; appended to ``mask_type`` to form the
            ``case`` name passed to the visualizer.
        mask_type: Perturbation type forwarded to ``mask.perturb_sequence``
            as ``perb_type`` (callers in this file pass 'freeze' or 'reverse').
        image_width: Forwarded to ``visualize_results_on_gradcam``.
        image_height: Forwarded to ``visualize_results_on_gradcam``.

    Returns:
        List of the combined uint8 frames, in frame order.
    """
    # The perturbed sequence does not depend on the frame index, so compute it
    # once instead of re-running the full perturbation for every frame.
    # A copy of the mask is still passed in case perturb_sequence modifies it
    # (e.g. while snapping values). This assumes perturb_sequence is
    # deterministic for identical inputs -- NOTE(review): confirm.
    perturbed_sequence = mask.perturb_sequence(
        input_sequence, time_mask.copy(), perb_type=mask_type, snap_values=True)

    combined_images = []
    for frame_idx in range(FLAGS.seq_length):
        combined_img = np.concatenate(
            (np.uint8(input_sequence[0, frame_idx, :, :, :]),
             np.uint8(gradcams[frame_idx]),
             np.uint8(perturbed_sequence[0, frame_idx, :, :, :])),
            axis=1)
        combined_images.append(combined_img)
        # Output file names are 1-based: img01.jpg, img02.jpg, ...
        cv2.imwrite(os.path.join(output_folder, "img%02d.jpg" % (frame_idx + 1)),
                    combined_img)

    visualize_results_on_gradcam(combined_images, time_mask, image_width, image_height,
                                 root_dir=output_folder, case=mask_type + video_id)
    return combined_images
def find_masks(dat_loader, model, config, lam1, lam2, N, ita, maskType="gradient",
               temporalMaskType="freeze", classOI=None, verbose=True,
               maxMaskLength=None, doGradCam=False, runTempMask=True):
    '''Find temporal masks (and optionally Grad-CAM maps) for selected KTH clips.

    Only clips whose label tag contains every part of one of the hard-coded
    ``clips_of_interest`` entries are processed. For each such clip, a
    temporal mask is optimized with Adam to minimize
    ``lam1 * L1(mask) + lam2 * TV(mask) + class score of the perturbed clip``.

    Per-clip freeze/reverse class scores and images are written under
    ``cam_saved_images/<args.subDir>/...``. All per-clip results are pickled
    to ``results/``.

    Input:
        dat_loader: iterable yielding ``(input, target, label)`` batches, where
            ``label`` holds one string tag per clip (should be val/test).
        model: model to evaluate masks on; must expose ``.module`` when
            ``doGradCam`` is set.
        config: dict. Reads ``"splitType"``, which selects the clip list
            (``"original"`` or anything else). Also reads ``"gradCamType"``:
            ``"guessed"`` explains the predicted class, anything else explains
            the true target.
        lam1: weighting factor for the L1 loss.
        lam2: weighting factor for the TV-norm loss.
        N: maximum number of gradient-descent iterations per mask.
        ita, maskType, classOI, maxMaskLength: not used; kept for interface
            compatibility.
        temporalMaskType: perturbation type used while optimizing the mask.
        verbose: print the resulting mask for every processed clip.
        doGradCam: also compute Grad-CAM heatmaps and write combined images.
        runTempMask: run the temporal-mask optimization.

    Returns:
        List of sigmoid-activated time masks, one per processed clip (empty
        when ``runTempMask`` is False).
    '''
    model.eval()
    masks = []
    resultsPath = "results/"
    clipsTimeMaskResults = []
    clipsGradCamResults = []
    os.makedirs(resultsPath, exist_ok=True)

    # Clips to analyse: every string of an entry must occur in a clip's tag.
    if config["splitType"] == "original":
        clips_of_interest = [
            ["person17", "boxing", "d1", "_1"], ["person17", "boxing", "d2", "_1"],
            ["person18", "boxing", "d3", "_1"], ["person18", "boxing", "d4", "_1"],
            ["person17", "handclapping", "d1", "_1"], ["person17", "handclapping", "d2", "_1"],
            ["person18", "handclapping", "d3", "_1"], ["person18", "handclapping", "d4", "_1"],
            ["person17", "handwaving", "d1", "_1"], ["person17", "handwaving", "d2", "_1"],
            ["person18", "handwaving", "d3", "_1"], ["person18", "handwaving", "d4", "_1"],
            ["person24", "jogging", "d1", "_1"], ["person24", "jogging", "d2", "_1"],
            ["person25", "jogging", "d3", "_1"], ["person25", "jogging", "d4", "_1"],
            ["person24", "running", "d1", "_1"], ["person24", "running", "d2", "_1"],
            ["person25", "running", "d3", "_1"], ["person25", "running", "d4", "_1"],
            ["person24", "walking", "d1", "_1"], ["person24", "walking", "d2", "_1"],
            ["person25", "walking", "d3", "_1"], ["person25", "walking", "d4", "_1"],
        ]
    else:
        clips_of_interest = [
            ["person07", "boxing", "d1", "_1"], ["person07", "boxing", "d2", "_1"],
            ["person08", "boxing", "d3", "_1"], ["person08", "boxing", "d4", "_1"],
            ["person07", "handclapping", "d1", "_1"], ["person07", "handclapping", "d2", "_1"],
            ["person08", "handclapping", "d3", "_1"], ["person08", "handclapping", "d4", "_1"],
            ["person07", "handwaving", "d1", "_1"], ["person07", "handwaving", "d2", "_1"],
            ["person08", "handwaving", "d3", "_1"], ["person08", "handwaving", "d4", "_1"],
            ["person09", "jogging", "d1", "_1"], ["person09", "jogging", "d2", "_1"],
            ["person10", "jogging", "d3", "_1"], ["person10", "jogging", "d4", "_1"],
            ["person09", "running", "d1", "_1"], ["person09", "running", "d2", "_1"],
            ["person10", "running", "d3", "_1"], ["person10", "running", "d4", "_1"],
            ["person09", "walking", "d1", "_1"], ["person09", "walking", "d2", "_1"],
            ["person10", "walking", "d3", "_1"], ["person10", "walking", "d4", "_1"],
        ]

    for i, (inputs, target, label) in enumerate(dat_loader):
        if i % 50 == 0:
            print("on idx: ", i)
        input_var = inputs.to(device)
        target = target.to(device)
        model.zero_grad()
        # eta is for breaking out of the grad desc early if it hasn't improved
        eta = 0.00001
        # Model output for this batch; computed lazily, at most once per batch.
        output = None

        # Iterate over the clips actually present: the last batch of a loader
        # may be smaller than config["batch_size"].
        for intraBidx in range(input_var.size(0)):
            targTag = label[intraBidx]
            if not any(all(part in targTag for part in coi) for coi in clips_of_interest):
                continue
            if output is None:
                output = model(input_var)

            # Class whose score the mask / Grad-CAM should explain. Computed
            # outside the runTempMask branch because Grad-CAM needs it too.
            if config["gradCamType"] == "guessed":
                maskTarget = torch.zeros((input_var.size(0), 1)).long()
                maskTarget[intraBidx] = torch.argmax(output[intraBidx])
            else:
                maskTarget = target

            # One output directory per clip, shared by the score files, the
            # Grad-CAM strips and the mask visualization.
            clip_dir = os.path.join(
                "cam_saved_images", args.subDir, str(target[intraBidx].item()),
                targTag + "g_" + str(torch.argmax(output[intraBidx]).item())
                + "_gs%5.4f" % torch.max(output[intraBidx]).item()
                + "_cs%5.4f" % output[intraBidx][target[intraBidx]].item(),
                "combined")
            os.makedirs(clip_dir, exist_ok=True)

            time_mask = None
            if runTempMask:
                # gradient descent for finding temporal masks
                model.zero_grad()
                time_mask = mask.init_mask(input_var, model, intraBidx, maskTarget,
                                           thresh=0.9, mode="central",
                                           maskPertType=temporalMaskType)
                optimizer = torch.optim.Adam([time_mask], lr=0.2)
                old_loss = 999999
                for nidx in range(N):
                    if nidx % 25 == 0:
                        print("on nidx: ", nidx)
                    mask_clip = torch.sigmoid(time_mask)
                    l1loss = lam1 * torch.sum(torch.abs(mask_clip))
                    tvnorm_loss = lam2 * mask.calc_TVNorm(mask_clip, p=3, q=3)
                    class_loss = model(
                        mask.perturb_sequence(input_var, mask_clip, perbType=temporalMaskType))
                    class_loss = class_loss[intraBidx, maskTarget[intraBidx]]
                    loss = l1loss + tvnorm_loss + class_loss
                    if abs(old_loss - loss.item()) < eta:
                        break
                    # Remember this iteration's loss; without this update the
                    # early-stopping test above could never trigger.
                    old_loss = loss.item()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                time_mask = torch.sigmoid(time_mask)

                # Score under the optimized perturbation (last iteration's value).
                with open(clip_dir + "/ClassScoreFreezecase" + targTag + ".txt", "w+") as f:
                    f.write(str(class_loss.item()))

                class_loss_reverse = model(
                    mask.perturb_sequence(input_var, time_mask, perbType="reverse"))
                class_loss_reverse = class_loss_reverse[intraBidx, maskTarget[intraBidx]]
                with open(clip_dir + "/ClassScoreReversecase" + targTag + ".txt", "w+") as f:
                    f.write(str(class_loss_reverse.item()))

                # as soon as you have the time mask, and freeze/reverse scores,
                # Add results for current clip in list of timemask results
                clipsTimeMaskResults.append({
                    'true_class': int(target[intraBidx].item()),
                    'pred_class': int(torch.argmax(output[intraBidx]).item()),
                    'video_id': targTag,
                    'time_mask': time_mask.detach().cpu().numpy(),
                    'original_score_guess': torch.max(output[intraBidx]).item(),
                    'original_score_true': output[intraBidx][target[intraBidx]].item(),
                    'freeze_score': class_loss.item(),
                    'reverse_score': class_loss_reverse.item()
                })
                if verbose:
                    print("resulting mask is: ", time_mask)

            if doGradCam:
                RESIZE_SIZE_WIDTH = 160
                RESIZE_SIZE_HEIGHT = 120
                grad_cam = GradCamVideo(model=model.module,
                                        target_layer_names=['Mixed_5c'],
                                        class_dict=None,
                                        use_cuda=True,
                                        input_spatial_size=(RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT),
                                        normalizePerFrame=True,
                                        archType="I3D")
                input_to_model = input_var[intraBidx][None, :, :, :, :]
                if config["gradCamType"] == "guessed":
                    target_index = torch.argmax(output[intraBidx])
                else:
                    target_index = maskTarget[intraBidx]
                # Named gradcam_heatmap (not `mask`) so the `mask` module used
                # above stays accessible for later clips.
                gradcam_heatmap, _ = grad_cam(input_to_model, target_index)
                # as soon as you have the numpy array of gradcam heatmap, add it to list of GCheatmaps
                clipsGradCamResults.append({
                    'true_class': int(target[intraBidx].item()),
                    'pred_class': int(torch.argmax(output[intraBidx]).item()),
                    'video_id': targTag,
                    'GCHeatMap': gradcam_heatmap
                })

                RESIZE_FLAG = 0
                if runTempMask:
                    # NOTE(review): these argument lists follow an older
                    # viz.create_image_arrays signature; confirm against viz.
                    viz.create_image_arrays(input_var, gradcam_heatmap, time_mask, intraBidx,
                                            "freeze", clip_dir, targTag, RESIZE_FLAG,
                                            RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)
                    viz.create_image_arrays(input_var, gradcam_heatmap, time_mask, intraBidx,
                                            "reverse", clip_dir, targTag, RESIZE_FLAG,
                                            RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)

            if runTempMask:
                viz.vizualize_results(
                    input_var[intraBidx],
                    mask.perturb_sequence(input_var, time_mask,
                                          perbType=temporalMaskType)[intraBidx],
                    time_mask, rootDir=clip_dir, case=targTag,
                    markImgs=True, iterTest=False)
                masks.append(time_mask)

    # finally, write pickle files to disk
    with open(resultsPath + "I3d_KTH_allTimeMaskResults_original_" + args.subDir + ".p", "wb") as f:
        pickle.dump(clipsTimeMaskResults, f)
    with open(resultsPath + "I3d_KTH_allGradCamResults_original_" + args.subDir + ".p", "wb") as f:
        pickle.dump(clipsGradCamResults, f)
    return masks
def find_masks(dat_loader, model, hyper_params, lam1, lam2, N, maskType="gradient",
               temporalMaskType="freeze", classOI=None, verbose=True,
               maxMaskLength=None, doGradCam=False, runTempMask=True):
    '''Find temporal masks (and optionally Grad-CAM maps) for clips of interest.

    For every selected clip, a temporal mask is optimized with Adam to minimize
    ``lam1 * L1(mask) + lam2 * TV(mask) + class score of the perturbed clip``.

    Per-clip freeze/reverse class scores are written under
    ``cam_saved_images/<args.subDir>/...``. All per-clip results are pickled
    to ``results/``.

    Input:
        dat_loader: iterable yielding ``(sequence, label, video_id)`` batches
            (should be val/test).
        model: model to evaluate masks on; must expose ``.module`` when
            ``doGradCam`` is set.
        hyper_params: dict; reads ``"gradCamType"``. ``"guessed"`` explains
            the predicted class; anything else explains the true label.
        lam1: weighting factor for the L1 loss.
        lam2: weighting factor for the TV-norm loss.
        N: maximum number of gradient-descent iterations per mask.
        classOI: optional path to a CSV whose column names are class ids (as
            strings) and whose cells are video ids to process for that class.
            When None, every clip is processed.
        temporalMaskType: perturbation type used while optimizing the mask.
        doGradCam: also compute Grad-CAM heatmaps and write combined images.
        runTempMask: run the temporal-mask optimization.
        maskType, verbose, maxMaskLength: not used; kept for interface
            compatibility.

    Returns:
        List of sigmoid-activated time masks, one per processed clip (empty
        when ``runTempMask`` is False).
    '''
    model.eval()
    masks = []
    # Only read the selection CSV when one was given (classOI defaults to None).
    df = pd.read_csv(classOI) if classOI is not None else None
    results_path = 'results/'
    clips_time_mask_results = []
    clips_grad_cam_results = []
    os.makedirs(results_path, exist_ok=True)

    for i, (sequence, label, video_id) in enumerate(dat_loader):
        if i % 50 == 0:
            print("on idx: ", i)
        input_var = sequence.to(device)
        label = label.to(device)
        model.zero_grad()
        # eta is for breaking out of the grad desc early if it hasn't improved
        eta = 0.00001
        # Model output for this batch; computed lazily, at most once per batch.
        output = None

        # Iterate over the clips actually present: the last batch of a loader
        # may be smaller than hyper_params["batch_size"].
        for batch_index in range(input_var.size(0)):
            true_class = label[batch_index].item()
            true_class_str = str(true_class)
            # Per-clip id; kept separate so the batch-level `video_id` is not
            # overwritten while iterating over the batch.
            clip_id = video_id[batch_index]

            # only look at cases where clip is of a certain class (if class of interest 'classOI' was given)
            if classOI is not None and not (
                    true_class_str in df.columns
                    and int(clip_id) in df[true_class_str].tolist()):
                continue

            if output is None:
                output = model(input_var)

            # Class whose score the mask / Grad-CAM should explain. Computed
            # outside the runTempMask branch because Grad-CAM needs it too.
            if hyper_params["gradCamType"] == "guessed":
                mask_target = torch.zeros((input_var.size(0), 1)).long()
                mask_target[batch_index] = torch.argmax(output[batch_index])
            else:
                mask_target = label

            pred_class = int(torch.argmax(output[batch_index]).item())
            # Keep the score as a float; int() truncated it (to 0 for scores below 1).
            original_score_guess = torch.max(output[batch_index]).item()
            original_score_true = output[batch_index][label[batch_index]].item()

            # One output directory per clip, shared by the score files and
            # the Grad-CAM strips.
            clip_dir = os.path.join(
                "cam_saved_images", args.subDir, str(true_class),
                clip_id + "g_" + str(pred_class)
                + "_gs%5.4f" % original_score_guess
                + "_cs%5.4f" % original_score_true,
                "combined")

            time_mask = None
            if runTempMask:
                # gradient descent for finding the temporal mask
                model.zero_grad()
                time_mask = mask.init_mask(
                    input_var, model, batch_index, mask_target, threshold=0.9,
                    mode="central", mask_type=temporalMaskType)
                optimizer = torch.optim.Adam([time_mask], lr=0.2)
                oldLoss = 999999
                for nidx in range(N):
                    if nidx % 25 == 0:
                        print("on nidx: ", nidx)
                    mask_clip = torch.sigmoid(time_mask)
                    l1loss = lam1 * torch.sum(torch.abs(mask_clip))
                    tvnorm_loss = lam2 * mask.calc_tv_norm(mask_clip, p=3, q=3)
                    class_loss = model(mask.perturb_sequence(
                        input_var, mask_clip, perturbation_type=temporalMaskType))
                    class_loss = class_loss[batch_index, mask_target[batch_index]]
                    loss = l1loss + tvnorm_loss + class_loss
                    if abs(oldLoss - loss.item()) < eta:
                        break
                    # Remember this iteration's loss; without this update the
                    # early-stopping test above could never trigger.
                    oldLoss = loss.item()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                time_mask = torch.sigmoid(time_mask)

                os.makedirs(clip_dir, exist_ok=True)
                with open(clip_dir + "/ClassScoreFreezecase" + clip_id + ".txt", "w+") as f:
                    f.write(str(class_loss.item()))

                class_loss_reverse = model(mask.perturb_sequence(
                    input_var, time_mask, perturbation_type="reverse"))
                class_loss_reverse = class_loss_reverse[batch_index, mask_target[batch_index]]
                with open(clip_dir + "/ClassScoreReversecase" + clip_id + ".txt", "w+") as f:
                    f.write(str(class_loss_reverse.item()))

                # as soon as you have the time mask, and freeze/reverse scores,
                # Add results for current clip in list of timemask results
                clips_time_mask_results.append({
                    'true_class': true_class,
                    'pred_class': pred_class,
                    'video_id': clip_id,
                    'time_mask': time_mask.detach().cpu().numpy(),
                    'original_score_guess': original_score_guess,
                    'original_score_true': original_score_true,
                    'freeze_score': class_loss.item(),
                    'reverse_score': class_loss_reverse.item()
                })

            if doGradCam:
                # NOTE(review): RESIZE_SIZE_WIDTH / RESIZE_SIZE_HEIGHT are not
                # defined in this function; they must exist at module level.
                grad_cam = GradCamVideo(model=model.module,
                                        target_layer_names=['Mixed_5c'],
                                        class_dict=None,
                                        use_cuda=True,
                                        input_spatial_size=(RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT),
                                        normalizePerFrame=True,
                                        archType="I3D")
                input_to_model = input_var[batch_index][None, :, :, :, :]
                if hyper_params["gradCamType"] == "guessed":
                    target_index = torch.argmax(output[batch_index])
                else:
                    target_index = mask_target[batch_index]
                # Named gradcam_heatmap (not `mask`) so the `mask` module used
                # above stays accessible for later clips.
                gradcam_heatmap, _ = grad_cam(input_to_model, target_index)
                # as soon as you have the numpy array of gradcam heatmap, add it to list of GCheatmaps
                clips_grad_cam_results.append({
                    'true_class': int(label[batch_index].item()),
                    'pred_class': pred_class,
                    'video_id': int(clip_id),
                    'GCHeatMap': gradcam_heatmap
                })

                os.makedirs(clip_dir, exist_ok=True)
                RESIZE_FLAG = 0
                if runTempMask:
                    # NOTE(review): these argument lists follow an older
                    # viz.create_image_arrays signature; confirm against viz.
                    viz.create_image_arrays(
                        input_var, gradcam_heatmap, time_mask, batch_index, "freeze",
                        clip_dir, clip_id, RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)
                    viz.create_image_arrays(
                        input_var, gradcam_heatmap, time_mask, batch_index, "reverse",
                        clip_dir, clip_id, RESIZE_FLAG, RESIZE_SIZE_WIDTH, RESIZE_SIZE_HEIGHT)

            if runTempMask:
                masks.append(time_mask)

    # finally, write pickle files to disk
    name_suffix = args.subDir + "_" + str(classOI) + "_"
    with open(results_path + "allTimeMaskResults_" + name_suffix + ".p", "wb") as f:
        pickle.dump(clips_time_mask_results, f)
    with open(results_path + "allGradCamResults_" + name_suffix + ".p", "wb") as f:
        pickle.dump(clips_grad_cam_results, f)
    return masks
def main(argv):
    """Optimize temporal masks (and Grad-CAM maps) for the clips listed in a CSV.

    Reads ``FLAGS.clip_selection`` (columns ``Subject``, ``Video_ID``,
    ``Label``) and rebuilds the C-LSTM graph with a learnable temporal mask in
    front of it. It then restores
    ``/workspace/checkpoints/<FLAGS.checkpoint_name>``.

    For each listed clip it runs Adam on
    ``lambda_1 * L1(mask) + lambda_2 * TV(mask) + class score``. Scores,
    Grad-CAM image strips and mask visualizations are written under
    ``cam_saved_images/<FLAGS.output_folder>/...``.

    Args:
        argv: Unused.

    Raises:
        ValueError: if ``FLAGS.focus_type`` is neither 'correct' nor 'guessed'.
    """
    df = pd.read_csv(FLAGS.clip_selection)  # DataFrame containing the clips to run on.
    if FLAGS.focus_type not in ('correct', 'guessed'):
        raise ValueError("FLAGS.focus_type must be 'correct' or 'guessed', got %r"
                         % (FLAGS.focus_type,))

    # First we need to recreate the same variables as in the model.
    tf.reset_default_graph()
    seq_shape = (FLAGS.batch_size, FLAGS.seq_length, FLAGS.image_height, FLAGS.image_width, 3)
    frame_range = range(FLAGS.seq_length)
    # A post-sigmoid mask of zeros makes the recurrence below return the input
    # unchanged. It is fed directly into mask_clip whenever the unperturbed
    # prediction is needed. Feeding mask_var = 0 instead would give
    # sigmoid(0) = 0.5 for every frame, i.e. a half-frozen clip.
    identity_mask = np.zeros(FLAGS.seq_length)

    # Create variable to save original input sequence
    with tf.variable_scope('original_input'):
        original_input_plhdr = tf.placeholder(tf.float32, seq_shape)
        original_input_var = tf.get_variable('original_input', seq_shape,
                                             dtype=tf.float32, trainable=False)
        original_input_assign = original_input_var.assign(original_input_plhdr)

    with tf.variable_scope('mask'):
        # Create variable for the temporal mask
        mask_plhdr = tf.placeholder(tf.float32, [FLAGS.seq_length])
        mask_var = tf.get_variable('input_mask', [FLAGS.seq_length],
                                   dtype=tf.float32, trainable=True)
        mask_assign = tf.assign(mask_var, mask_plhdr)
        mask_clip = tf.nn.sigmoid(mask_var)

    with tf.variable_scope('perturb'):
        frame_inds = tf.placeholder(tf.int32, shape=(None,), name='frame_inds')

        def recurrence(last_value, current_elem):
            # Blend the current frame with the previous (perturbed) frame;
            # mask value 1 repeats the previous frame, 0 keeps the original.
            update_tensor = (1 - mask_clip[current_elem]) * original_input_var[:, current_elem, :, :, :] + \
                mask_clip[current_elem] * last_value
            return update_tensor

        perturb_op = tf.scan(fn=recurrence, elems=frame_inds,
                             initializer=original_input_var[:, 0, :, :, :])
        # tf.scan stacks along axis 0 (frame-major). Move the batch axis back
        # to the front before fixing the static shape; a plain reshape would
        # mix clips together whenever batch_size > 1.
        perturb_op = tf.reshape(tf.transpose(perturb_op, [1, 0, 2, 3, 4]), seq_shape)

    y = tf.placeholder(tf.float32, [FLAGS.batch_size, NUM_CLASSES])
    logits, clstm_3 = clstm.clstm(perturb_op, bn=False, is_training=False,
                                  num_classes=NUM_CLASSES)
    after_softmax = tf.nn.softmax(logits)

    # Settings for temporal mask method
    N = FLAGS.nb_iterations_graddescent
    maskType = 'gradient'
    verbose = True
    do_gradcam = True
    run_temp_mask = True

    # Restore everything except the mask and stored input, which are not in
    # the checkpoint. Names are stripped of ':0' to match the checkpoint keys.
    variables_to_restore = {
        variable.name.replace(':0', ''): variable
        for variable in tf.global_variables()
        if not variable.name.startswith(('mask', 'original_input'))
    }

    with tf.Session() as sess:
        saver = tf.train.Saver(var_list=variables_to_restore)  # All but the input which is a variable
        saver.restore(sess, "/workspace/checkpoints/" + FLAGS.checkpoint_name)

        l1loss = FLAGS.lambda_1 * tf.reduce_sum(tf.abs(mask_clip))
        tvnormLoss = FLAGS.lambda_2 * mask.calc_TV_norm(mask_clip, p=3, q=3)
        if FLAGS.focus_type == 'correct':
            label_index = tf.reshape(tf.argmax(y, axis=1), [])
        else:  # 'guessed'
            label_index = tf.reshape(tf.argmax(logits, axis=1), [])
        class_loss = after_softmax[:, label_index]
        # Cast as same type as l1 and TV.
        class_loss = tf.cast(class_loss, tf.float32)
        loss_function = l1loss + tvnormLoss + class_loss

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate_start)
        train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'mask')
        with tf.variable_scope('minimize'):
            training_op = optimizer.minimize(loss_function, var_list=train_var)
        # Built once and re-run per clip so Adam moments do not leak between clips.
        optimizer_init_op = tf.variables_initializer(optimizer.variables())
        sess.run(optimizer_init_op)

        for _, row in df.iterrows():
            subject = row['Subject']
            video_ID = row['Video_ID']
            label = row['Label']
            # Retrieve the frame paths for this clip
            paths = get_paths_for_sequence(subject, video_ID)
            # Prepare one "batch" of data (bs=1)
            input_var, label = data_for_one_sequence_5D(paths, label)

            current_class = np.argmax(label)
            print(current_class)
            # Only process clips listed in the CSV under this class.
            if video_ID not in df.loc[df['Label'] == current_class, 'Video_ID'].tolist():
                continue
            print("found clip of interest ", current_class, video_ID)

            unperturbed_feed = {mask_clip: identity_mask,
                                original_input_var: input_var,
                                frame_inds: frame_range}
            preds, output = sess.run([logits, after_softmax], feed_dict=unperturbed_feed)
            print('np argmax preds', np.argmax(preds))
            # eta is for breaking out of the grad desc early if it hasn't improved
            eta = 0.00001
            # NOTE(review): this threshold is applied to raw logits, although
            # the message talks about the guess; confirm it should not use
            # the softmax `output` instead.
            if preds[:, int(current_class)] < 0.1:
                print('the guess for the correct class was less than 0.1')
                continue

            save_path = os.path.join(
                "cam_saved_images", FLAGS.output_folder, str(np.argmax(label)),
                video_ID + "g_" + str(np.argmax(output))
                + "_cs%5.4f" % output[:, np.argmax(label)]
                + "gs%5.4f" % output[:, np.argmax(output)],
                "combined")
            os.makedirs(save_path, exist_ok=True)

            if run_temp_mask:
                # Loss = lam1*||Mask size|| + lam2*TV norm + class_score
                if maskType == 'gradient':
                    start_mask = mask.init_mask(input_var, mask_var, original_input_var,
                                                frame_inds, after_softmax, sess, label,
                                                thresh=0.9, mode="central",
                                                mask_pert_type=FLAGS.temporal_mask_type)
                # Initialize mask variable, stored input and a fresh Adam state.
                sess.run(mask_assign, {mask_plhdr: start_mask})
                sess.run(original_input_assign, {original_input_plhdr: input_var})
                sess.run(optimizer_init_op)

                oldLoss = 999999
                for nidx in range(N):
                    if nidx % 10 == 0:
                        print("on nidx: ", nidx)
                        print("mask_clipped is: ", sess.run(mask_clip))
                    _, loss_value, l1value, tvvalue, classlossvalue = sess.run(
                        [training_op, loss_function, l1loss, tvnormLoss, class_loss],
                        feed_dict={y: label, frame_inds: frame_range,
                                   original_input_var: input_var})
                    print("LOSS: {}, l1loss: {}, tv: {}, class: {}".format(
                        loss_value, l1value, tvvalue, classlossvalue))
                    if abs(oldLoss - loss_value) < eta:
                        break
                    # Remember this iteration's loss; without this update the
                    # early-stopping test above could never trigger.
                    oldLoss = loss_value

                time_mask = sess.run(mask_clip)
                with open(save_path + "/ClassScoreFreezecase" + video_ID + ".txt", "w+") as f:
                    f.write(str(classlossvalue))

                if FLAGS.temporal_mask_type == 'reverse':
                    perturbed_sequence = mask.perturb_sequence(input_var, time_mask,
                                                               perb_type='reverse')
                    # y is fed because class_loss depends on it when
                    # focus_type == 'correct'.
                    class_loss_rev = sess.run(class_loss, feed_dict={
                        mask_clip: identity_mask,
                        original_input_var: perturbed_sequence,
                        frame_inds: frame_range,
                        y: label})
                    with open(save_path + "/ClassScoreReversecase" + video_ID + ".txt", "w+") as f:
                        f.write(str(class_loss_rev))
                if verbose:
                    print("resulting mask is: ", time_mask)

            if do_gradcam:
                if FLAGS.focus_type == "guessed":
                    target_index = np.argmax(output)
                else:  # 'correct'
                    target_index = np.argmax(label)
                gradcam = gc.get_gradcam(sess, logits, clstm_3, y, original_input_var,
                                         mask_var, frame_inds, input_var, label,
                                         target_index, FLAGS.image_height, FLAGS.image_width)

            # beginning of gradcam write to disk
            if do_gradcam and run_temp_mask:
                viz.create_image_arrays(input_var, gradcam, time_mask, save_path, video_ID,
                                        'freeze', FLAGS.image_width, FLAGS.image_height)
                if FLAGS.temporal_mask_type == 'reverse':
                    # Also create the image arrays for the reverse operation.
                    viz.create_image_arrays(input_var, gradcam, time_mask, save_path, video_ID,
                                            'reverse', FLAGS.image_width, FLAGS.image_height)
            if run_temp_mask:
                # NOTE(review): always visualizes the 'reverse' perturbation,
                # even when FLAGS.temporal_mask_type is 'freeze'.
                viz.visualize_results(input_var,
                                      mask.perturb_sequence(input_var, time_mask,
                                                            perb_type='reverse'),
                                      time_mask, root_dir=save_path, case=video_ID,
                                      mark_imgs=True, iter_test=False)
def main(argv):
    """Optimize temporal masks (and Grad-CAM maps) for selected validation clips.

    Steps:

    - Reads ``FLAGS.clip_selection``: column names are class ids (as strings)
      and cells are video ids to process.
    - Rebuilds the C-LSTM graph with a learnable temporal mask in front of it
      and restores a fixed checkpoint.
    - Iterates ``FLAGS.val_data`` one batch at a time. For each selected clip
      it runs Adam on ``lambda_1 * L1(mask) + lambda_2 * TV(mask) + class score``.

    Scores, Grad-CAM image strips and mask visualizations are written under
    ``cam_saved_images/<FLAGS.output_folder>/...``.

    Args:
        argv: Unused.

    Raises:
        ValueError: if ``FLAGS.focus_type`` is neither 'correct' nor 'guessed'.
    """
    df = pd.read_csv(FLAGS.clip_selection)  # DataFrame containing the clips to run on.
    if FLAGS.focus_type not in ('correct', 'guessed'):
        raise ValueError("FLAGS.focus_type must be 'correct' or 'guessed', got %r"
                         % (FLAGS.focus_type,))

    # First we need to recreate the same variables as in the model.
    tf.reset_default_graph()
    seq_shape = (FLAGS.batch_size, FLAGS.seq_length, FLAGS.image_size, FLAGS.image_size, 3)
    frame_range = range(FLAGS.seq_length)
    # A post-sigmoid mask of zeros makes the recurrence below return the input
    # unchanged. It is fed directly into mask_clip whenever the unperturbed
    # prediction is needed. Feeding mask_var = 0 instead would give
    # sigmoid(0) = 0.5 for every frame, i.e. a half-frozen clip.
    identity_mask = np.zeros(FLAGS.seq_length)

    # Create variable to save original input sequence
    with tf.variable_scope('original_input'):
        original_input_plhdr = tf.placeholder(tf.float32, seq_shape)
        original_input_var = tf.get_variable('original_input', seq_shape,
                                             dtype=tf.float32, trainable=False)
        original_input_assign = original_input_var.assign(original_input_plhdr)

    with tf.variable_scope('mask'):
        # Create variable for the temporal mask
        mask_plhdr = tf.placeholder(tf.float32, [FLAGS.seq_length])
        mask_var = tf.get_variable('input_mask', [FLAGS.seq_length],
                                   dtype=tf.float32, trainable=True)
        mask_assign = tf.assign(mask_var, mask_plhdr)
        mask_clip = tf.nn.sigmoid(mask_var)

    with tf.variable_scope('perturb'):
        frame_inds = tf.placeholder(tf.int32, shape=(None,), name='frame_inds')

        def recurrence(last_value, current_elem):
            # Blend the current frame with the previous (perturbed) frame;
            # mask value 1 repeats the previous frame, 0 keeps the original.
            update_tensor = (1 - mask_clip[current_elem]) * original_input_var[:, current_elem, :, :, :] + \
                mask_clip[current_elem] * last_value
            return update_tensor

        perturb_op = tf.scan(fn=recurrence, elems=frame_inds,
                             initializer=original_input_var[:, 0, :, :, :])
        # tf.scan stacks along axis 0 (frame-major). Move the batch axis back
        # to the front before fixing the static shape; a plain reshape would
        # mix clips together whenever batch_size > 1.
        perturb_op = tf.reshape(tf.transpose(perturb_op, [1, 0, 2, 3, 4]), seq_shape)

    y = tf.placeholder(tf.float32, [FLAGS.batch_size, NUM_CLASSES])
    logits, clstm_3 = clstm(perturb_op)
    after_softmax = tf.nn.softmax(logits)

    validation_dataset = create_dataset(FLAGS.val_data)
    # Re-initializable iterator
    iterator = tf.data.Iterator.from_structure(
        validation_dataset.output_types, validation_dataset.output_shapes)
    next_element = iterator.get_next()
    validation_init_op = iterator.make_initializer(validation_dataset, name='val_init_op')

    os.makedirs(FLAGS.output_folder, exist_ok=True)
    STEPS_VAL = int(FLAGS.nb_val_samples / FLAGS.batch_size)

    # Settings for temporal mask method
    N = FLAGS.nb_iterations_graddescent
    mask_type = 'gradient'
    verbose = True
    do_gradcam = True
    run_temp_mask = True

    # Restore everything except the mask and stored input, which are not in
    # the checkpoint. Names are stripped of ':0' to match the checkpoint keys.
    variables_to_restore = {
        variable.name.replace(':0', ''): variable
        for variable in tf.global_variables()
        if not variable.name.startswith(('mask', 'original_input'))
    }

    with tf.Session() as sess:
        saver = tf.train.Saver(var_list=variables_to_restore)  # All but the input which is a variable
        saver.restore(sess, "/workspace/checkpoints/3lyr_32_mom_wholeseq_bs8")
        sess.run(validation_init_op)

        l1loss = FLAGS.lambda_1 * tf.reduce_sum(tf.abs(mask_clip))
        tvnormLoss = FLAGS.lambda_2 * mask.calc_TV_norm(mask_clip, p=3, q=3)
        if FLAGS.focus_type == 'correct':
            label_index = tf.reshape(tf.argmax(y, axis=1), [])
        else:  # 'guessed'
            label_index = tf.reshape(tf.argmax(logits, axis=1), [])
        class_loss = after_softmax[:, label_index]
        # Cast as same type as l1 and TV.
        class_loss = tf.cast(class_loss, tf.float32)
        loss_function = l1loss + tvnormLoss + class_loss

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate_start)
        train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'mask')
        with tf.variable_scope('minimize'):
            training_op = optimizer.minimize(loss_function, var_list=train_var)
        # Built once and re-run per clip so Adam moments do not leak between clips.
        optimizer_init_op = tf.variables_initializer(optimizer.variables())
        sess.run(optimizer_init_op)

        for _ in range(STEPS_VAL):
            input_var, label, video_id = sess.run(next_element)
            # video_id is returned as an array of a bytes object,
            # like so: array([b'74225'], dtype=object)
            video_id = video_id[0].decode("utf-8")

            # only look at cases of a certain class (if the clip selection lists it)
            current_class = str(np.argmax(label))
            print(current_class)
            if not (current_class in df.columns
                    and int(video_id) in df[current_class].tolist()):
                continue
            print("found clip of interest ", current_class, video_id)

            unperturbed_feed = {mask_clip: identity_mask,
                                original_input_var: input_var,
                                frame_inds: frame_range}
            preds, output = sess.run([logits, after_softmax], feed_dict=unperturbed_feed)
            print('np argmax preds', np.argmax(preds))
            # eta is for breaking out of the grad desc early if it hasn't improved
            eta = 0.00001
            # NOTE(review): this threshold is applied to raw logits, although
            # the message talks about the guess; confirm it should not use
            # the softmax `output` instead.
            if preds[:, int(current_class)] < 0.1:
                print('the guess for the correct class was less than 0.1')
                continue

            save_path = os.path.join(
                "cam_saved_images", FLAGS.output_folder, str(np.argmax(label)),
                video_id + "g_" + str(np.argmax(output))
                + "_cs%5.4f" % output[:, np.argmax(label)]
                + "gs%5.4f" % output[:, np.argmax(output)],
                "combined")
            os.makedirs(save_path, exist_ok=True)

            if run_temp_mask:
                if mask_type == 'gradient':
                    start_mask = mask.init_mask(input_var, mask_var, original_input_var,
                                                frame_inds, after_softmax, sess, label,
                                                thresh=0.9, mode="central",
                                                mask_pert_type=FLAGS.temporal_mask_type)
                # Initialize mask variable, stored input and a fresh Adam state.
                sess.run(mask_assign, {mask_plhdr: start_mask})
                sess.run(original_input_assign, {original_input_plhdr: input_var})
                sess.run(optimizer_init_op)

                oldLoss = 999999
                for nidx in range(N):
                    if nidx % 10 == 0:
                        print("on nidx: ", nidx)
                        print("mask_clipped is: ", sess.run(mask_clip))
                    _, loss_value, l1value, tvvalue, classlossvalue = sess.run(
                        [training_op, loss_function, l1loss, tvnormLoss, class_loss],
                        feed_dict={y: label, frame_inds: frame_range,
                                   original_input_var: input_var})
                    print("LOSS: {}, l1loss: {}, tv: {}, class: {}".format(
                        loss_value, l1value, tvvalue, classlossvalue))
                    if abs(oldLoss - loss_value) < eta:
                        break
                    # Remember this iteration's loss; without this update the
                    # early-stopping test above could never trigger.
                    oldLoss = loss_value

                time_mask = sess.run(mask_clip)
                with open(save_path + "/ClassScoreFreezecase" + video_id + ".txt", "w+") as f:
                    f.write(str(classlossvalue))

                if FLAGS.temporal_mask_type == 'reverse':
                    perturbed_sequence = mask.perturb_sequence(input_var, time_mask,
                                                               perb_type='reverse')
                    # y is fed because class_loss depends on it when
                    # focus_type == 'correct'.
                    class_loss_rev = sess.run(class_loss, feed_dict={
                        mask_clip: identity_mask,
                        original_input_var: perturbed_sequence,
                        frame_inds: frame_range,
                        y: label})
                    with open(save_path + "/ClassScoreReversecase" + video_id + ".txt", "w+") as f:
                        f.write(str(class_loss_rev))
                if verbose:
                    print("resulting mask is: ", time_mask)

            if do_gradcam:
                if FLAGS.focus_type == "guessed":
                    target_index = np.argmax(output)
                else:  # 'correct'
                    target_index = np.argmax(label)
                gradcam = gc.get_gradcam(sess, logits, clstm_3, y, original_input_var,
                                         mask_var, frame_inds, input_var, label,
                                         target_index, FLAGS.image_size, FLAGS.image_size)

            # beginning of gradcam write to disk
            if do_gradcam and run_temp_mask:
                viz.create_image_arrays(
                    input_var, gradcam, time_mask, save_path, video_id, 'freeze',
                    FLAGS.image_size, FLAGS.image_size)
                if FLAGS.temporal_mask_type == 'reverse':
                    # Also create the image arrays for the reverse operation.
                    viz.create_image_arrays(
                        input_var, gradcam, time_mask, save_path, video_id, 'reverse',
                        FLAGS.image_size, FLAGS.image_size)
            if run_temp_mask:
                # NOTE(review): always visualizes the 'reverse' perturbation,
                # even when FLAGS.temporal_mask_type is 'freeze'.
                viz.visualize_results(
                    input_var,
                    mask.perturb_sequence(input_var, time_mask, perb_type='reverse'),
                    time_mask, root_dir=save_path, case=video_id,
                    mark_imgs=True, iter_test=False)