def _construct_feed(self, scaling, enabled_noise=False):
    """Build the feed_dict for one IG scaling step.

    Delegates to ``construct_feed``, additionally forwarding the
    pre-computed embedded layer when one was prepared for this target.

    Args:
        scaling: scaling coefficient applied to the perturbation target.
        enabled_noise: forwarded to ``construct_feed`` unchanged.

    Returns:
        The feed_dict produced by ``construct_feed``.
    """
    kwargs = {
        "info": self.info,
        "config": self.config,
        "scaling": scaling,
        "perturbation_target": self.perturbation_target,
        "enabled_noise": enabled_noise,
    }
    if 'embedded_layer' in self.ig_modal_target_data:
        kwargs["embedded_layer"] = self.ig_modal_target_data['embedded_layer']
    return construct_feed(self.batch_idx, self.placeholders, self.all_data, **kwargs)
def cal_integrated_gradients(self, sess, placeholders, prediction, divide_number):
    """Accumulate integrated gradients over ``divide_number`` scaling steps.

    Gradients of ``prediction`` w.r.t. the "embedded_layer" placeholders are
    evaluated with inputs scaled by (step+1)/divide_number and averaged,
    weighted by the target data (the Riemann-sum form of IG).

    Args:
        sess: session object
        placeholders: placeholder dict; "embedded_layer" is differentiated.
        prediction: prediction score (output of the network)
        divide_number: division number of a prediction score

    Side effects:
        Sets ``self.IGs`` (OrderedDict keyed by modality) and
        ``self.sum_of_ig``.
    """
    grad_ops = tf.gradients(prediction, placeholders["embedded_layer"])
    IGs = OrderedDict()
    self.logger.debug(f"self.ig_modal_target = {self.ig_modal_target}")
    for modal in self.ig_modal_target:
        self.logger.info(f"{modal}: np.zeros({self.shapes[modal]})")
        IGs[modal] = np.zeros(self.shapes[modal], dtype=np.float32)
    n_steps = float(divide_number)
    for step in range(divide_number):
        started = time.time()
        # NOTE(review): builds the feed directly rather than via
        # self._construct_feed — confirm whether the embedded layer should
        # also be fed here when one exists.
        feed_dict = construct_feed(self.batch_idx, self.placeholders, self.all_data,
                                   config=self.config, info=self.info,
                                   scaling=(step + 1) / n_steps,
                                   perturbation_target=self.perturbation_target)
        grads_out = sess.run(grad_ops, feed_dict=feed_dict)
        for idx, modal in enumerate(IGs):
            IGs[modal] += grads_out[idx][0] * self.ig_modal_target_data[modal] / n_steps
        self.logger.info(f'[IG] {step:3d}th / {divide_number} : '
                         f'[TIME] {(time.time() - started):7.4f}s')
    self.IGs = IGs
    # If IG is calculated correctly, "total of IG" approximately equals the
    # difference between the prediction score at scaling factor = 1 and at
    # scaling factor = 0.
    self.sum_of_ig = sum(np.sum(values) for values in self.IGs.values())
def cal_feature_IG_for_kg(sess, all_data, placeholders, info, config, prediction,
                          model=None, logger=None, verbosity=None, args=None):
    """Compute and dump integrated gradients for knowledge-graph targets.

    Depending on ``config['visualize_type']`` the attributed quantity is an
    edge score (``edge_score``), an edge loss (``edge_loss``), or the
    top-scoring class of a node prediction (any other value).

    Args:
        sess: session object.
        all_data: dataset container; ``label_list`` holds edge node pairs.
        placeholders: network placeholders passed to ``construct_feed``.
        info: dataset meta information.
        config: configuration dict; requires ``visualize_path``,
            ``visualize_target`` and ``visualize_type``.
        prediction: prediction score tensor (output of the network).
        model: model object providing ``score``/``loss`` for edge targets.
        logger: logger object.
        verbosity: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.

    Raises:
        ValueError: if ``visualize_target`` is missing from ``config``, or an
            edge ``visualize_type`` is not recognized.
    """
    divide_number = 30
    outdir = config["visualize_path"]
    os.makedirs(outdir, exist_ok=True)
    batch_idx = [0, ]  # assume batch size is only one.
    feed_dict = construct_feed(batch_idx, placeholders, all_data, config=config,
                               batch_size=1, info=info)
    if 'visualize_target' not in config.keys():
        raise ValueError('set "visualize_target" in your config.')
    logger.info(f"config['visualize_target'] = {config['visualize_target']}")
    if config['visualize_target'] is None:
        # visualize everything: one target per edge label or per node
        n_samples = all_data.label_list.shape[1] if 'edge' in config['visualize_type'] \
            else prediction.shape[1]
        logger.info('visualization targets are all.')
        logger.info(f'n_samples = {n_samples}')
        targets = range(n_samples)
    else:
        targets = [config['visualize_target'], ]
    for target in targets:
        if 'edge' in config['visualize_type']:
            if config['visualize_type'] == 'edge_score':
                _prediction = model.score[target]
            elif config['visualize_type'] == 'edge_loss':
                _prediction = model.loss[target]
            else:
                # FIX: raise instead of print("[ERROR]") + sys.exit(1), for
                # consistency with the ValueError raised above for a missing
                # "visualize_target".
                raise ValueError(f"unsupported visualize_type: {config['visualize_type']}")
            node1 = all_data.label_list[0, target, 0]
            node2 = all_data.label_list[0, target, 1]
            logger.info(f"edge target = {target} => {node1}-{node2}")
            filename = f'edgepred-{node1}-{node2}'
            vis_nodes = [node1, node2]
        else:
            # node visualization: attribute the argmax class of this node
            out_prediction = sess.run(prediction, feed_dict=feed_dict)
            target_index = np.argmax(out_prediction[:, target, :])
            _prediction = prediction[:, target, target_index]
            logger.info(f"target_index = {target_index}")
            filename = f'nodepred-{target}'
            vis_nodes = [target, ]
        visualizer = KnowledgeGraphVisualizer(outdir, info, config, batch_idx,
                                              placeholders, all_data, _prediction,
                                              logger=logger, model=model)
        visualizer.cal_integrated_gradients(sess, placeholders, _prediction, divide_number)
        visualizer.dump(filename, vis_nodes)
def cal_feature_IG(sess, all_data, placeholders, info, config, prediction,
                   ig_modal_target, ig_label_target, *,
                   model=None, logger=None, args=None):
    """Calculate integrated gradients (IG) for each visualized compound.

    For every selected compound the network is run once, the output to
    attribute is chosen according to ``ig_label_target``, IG is computed via
    ``CompoundVisualizer``, and the result is dumped as a ``.jbl`` file under
    ``config["visualize_path"]``.

    Args:
        sess: session object.
        all_data: dataset container, indexed by compound id.
        placeholders: network placeholders passed to ``construct_feed``.
        info: dataset meta information (``mol_info`` is used when present).
        config: configuration dict; ``visualize_path`` is required.
        prediction: prediction score tensor (output of the network).
        ig_modal_target: modality whose inputs are attributed.
        ig_label_target: which output to attribute: "max", "all", "correct",
            "uncorrect", "label", or an integer class index.
        model: model object; ``model.embedding`` is used for sequence data.
        logger: logger object.
        args: command-line arguments (``visualization_header``,
            ``visualize_resample_num``, ``verbose``, ``visualize_method``).
    """
    divide_number = 100
    header = "mol"
    if args is not None and args.visualization_header is not None:
        header = args.visualization_header
    outdir = config["visualize_path"]
    os.makedirs(outdir, exist_ok=True)
    mol_obj_list = info.mol_info["obj_list"] if "mol_info" in info else None
    tf_grads = None  # gradient op is built once and reused across compounds
    all_count = 0
    correct_count = 0
    visualize_ids = range(all_data.num)
    # NOTE(review): args is assumed non-None from here on — confirm callers.
    if args.visualize_resample_num:
        visualize_ids = np.random.choice(visualize_ids, args.visualize_resample_num,
                                         replace=False)
    for compound_id in visualize_ids:
        s = time.time()
        batch_idx = [compound_id]
        if all_data['sequences'] is not None and hasattr(model, "embedding"):
            # embed the sequence first so IG is taken w.r.t. the embedding
            _data = all_data['sequences']
            _data = np.expand_dims(_data[compound_id, ...], axis=0)
            _data = model.embedding(sess, _data)
            feed_dict = construct_feed(batch_idx, placeholders, all_data,
                                       batch_size=1, info=info, embedded_layer=_data)
        else:
            feed_dict = construct_feed(batch_idx, placeholders, all_data,
                                       batch_size=1, info=info)
        out_prediction = sess.run(prediction, feed_dict=feed_dict)
        # Normalize the prediction rank to (#data, #task, #class) to give
        # consistency with the multitask case.
        multitask = False
        if len(out_prediction.shape) == 1:
            out_prediction = out_prediction[:, np.newaxis, np.newaxis]
        elif len(out_prediction.shape) == 2:
            out_prediction = np.expand_dims(out_prediction, axis=1)
        elif len(out_prediction.shape) == 3:
            if out_prediction.shape[1] > 1:
                multitask = True
        # out_prediction: #data x #task x #class; labels: #data x #task/#label
        for idx in range(out_prediction.shape[1]):
            _out_prediction = out_prediction[0, idx, :]
            true_label = np.argmax(all_data.labels[compound_id]) if not multitask \
                else all_data.labels[compound_id, idx]
            # convert an assay string according to the prediction score
            if len(_out_prediction) > 2:  # softmax output
                assay_str = f"class{np.argmax(_out_prediction)}"
            elif len(_out_prediction) == 2:  # softmax output
                assay_str = "active" if _out_prediction[1] > 0.5 else "inactive"
            else:
                assay_str = "active" if _out_prediction > 0.5 else "inactive"
            _prediction = prediction[:, idx, :] if len(prediction.shape) == 3 else prediction
            # choose which output the gradient is attributed to
            if ig_label_target == "max":
                target_index = np.argmax(_out_prediction)
                target_prediction = _prediction[:, target_index]
                target_score = _out_prediction[target_index]
            elif ig_label_target == "all":
                target_prediction = _prediction
                target_index = "all"
                target_score = np.sum(_out_prediction)
            elif ig_label_target == "correct":
                target_index = np.argmax(_out_prediction)
                if not target_index == true_label:
                    continue  # only correctly predicted compounds
                target_prediction = _prediction[:, target_index]
                target_score = _out_prediction[target_index]
            elif ig_label_target == "uncorrect":
                target_index = np.argmax(_out_prediction)
                if target_index == true_label:
                    continue  # only mispredicted compounds
                target_prediction = _prediction[:, target_index]
                target_score = _out_prediction[target_index]
            elif ig_label_target == "label":
                target_index = true_label
                target_prediction = _prediction[:, target_index]
                target_score = _out_prediction[target_index]
            else:  # explicit class index
                target_index = int(ig_label_target)
                target_prediction = _prediction[:, target_index]
                target_score = _out_prediction[target_index]
            # FIX: narrowed the bare `except:` — it also swallowed
            # KeyboardInterrupt/SystemExit. mol_obj_list may be None, or
            # RDKit may fail to build a SMILES; both fall back to None.
            try:
                mol_name = Chem.MolToSmiles(mol_obj_list[compound_id])
                mol_obj = mol_obj_list[compound_id]
            except Exception:
                mol_name = None
                mol_obj = None
            if args.verbose:
                print(
                    f"No.{compound_id}, task={idx}: \"{mol_name}\": {assay_str} (score= {_out_prediction}, "
                    f"true_label= {true_label}, target_label= {target_index}, target_score= {target_score})"
                )
            else:
                print(
                    f"No.{compound_id}, task={idx}: \"{mol_name}\": {assay_str}"
                )
            visualizer = CompoundVisualizer(
                sess, outdir, compound_id, info, config, batch_idx, placeholders,
                all_data, target_prediction, logger=logger, model=model,
                ig_modal_target=ig_modal_target, perturbation_target=ig_modal_target,
                grads=tf_grads)
            # cache the gradient op built by the first visualizer
            tf_grads = visualizer.grads if tf_grads is None else tf_grads
            visualizer.cal_integrated_gradients(sess, divide_number,
                                                method=args.visualize_method)
            visualizer.check_IG(sess, target_prediction)
            visualizer.dump(
                f"{header}_{compound_id:04d}_task_{idx}_{assay_str}_{ig_modal_target}_scaling.jbl",
                additional_data={
                    "mol": mol_obj,
                    "prediction_score": target_score,
                    "target_label": target_index,
                    "true_label": true_label,
                })
            logger.info(
                f"prediction score: {target_score}\n"
                f"check score: {visualizer.end_score - visualizer.start_score}\n"
                f"sum of IG: {visualizer.sum_of_ig}\n"
                f"time : {time.time() - s}\n")
            all_count += 1
            if np.argmax(_out_prediction) == int(true_label):
                correct_count += 1
    # FIX: guard against ZeroDivisionError when nothing was visualized, e.g.
    # every compound was skipped by "correct"/"uncorrect" filtering.
    if all_count:
        logger.info(f"accuracy(visualized_data) = {correct_count/all_count}")
    else:
        logger.info("accuracy(visualized_data) = n/a (no compounds visualized)")