Esempio n. 1
0
 def _construct_feed(self, scaling, enabled_noise=False):
     """Build the feed dict for one integrated-gradient scaling step.

     Forwards this object's batch indices, placeholders, data, info, config
     and perturbation target to ``construct_feed``. When
     ``self.ig_modal_target_data`` contains an ``'embedded_layer'`` entry, it
     is additionally forwarded as the ``embedded_layer`` argument.

     Args:
         scaling: scaling coefficient forwarded to ``construct_feed``.
         enabled_noise: forwarded to ``construct_feed`` unchanged.

     Returns:
         Whatever ``construct_feed`` returns (the feed dict).
     """
     target_data = self.ig_modal_target_data
     optional_kwargs = {}
     if 'embedded_layer' in target_data:
         optional_kwargs['embedded_layer'] = target_data['embedded_layer']
     return construct_feed(self.batch_idx,
                           self.placeholders,
                           self.all_data,
                           info=self.info,
                           config=self.config,
                           scaling=scaling,
                           perturbation_target=self.perturbation_target,
                           enabled_noise=enabled_noise,
                           **optional_kwargs)
Esempio n. 2
0
    def cal_integrated_gradients(self, sess, placeholders, prediction, divide_number):
        """Approximate integrated gradients (IG) for every modality in ``self.ig_modal_target``.

        The input is scaled by ``(step + 1) / divide_number`` for
        ``step = 0 .. divide_number - 1`` (right Riemann sum). At each step the
        gradient of ``prediction`` with respect to ``placeholders["embedded_layer"]``
        is evaluated, multiplied by the modality's entry in
        ``self.ig_modal_target_data`` and divided by ``divide_number``. The
        per-step results are accumulated into float32 arrays of shape
        ``self.shapes[modality]``.

        Args:
            sess: session object used to run the gradient tensors.
            placeholders: dict of placeholders; only its "embedded_layer" entry is
                used here (the tensor differentiated against). The feed itself is
                built from ``self.placeholders``.
            prediction: prediction score (output of the network).
            divide_number: number of scaling steps (division number of a prediction score).

        Side effects:
            Sets ``self.IGs`` (OrderedDict: modality name -> IG array) and
            ``self.sum_of_ig`` (sum of all IG values).
        """
        # NOTE(review): tf.gradients returns one gradient per tensor in its second
        # argument; a single tensor is passed here, so only the first modality has a
        # matching gradient (a second modality would hit an IndexError). Confirm
        # ig_modal_target holds a single entry.
        grad_tensors = tf.gradients(prediction, placeholders["embedded_layer"])
        IGs = OrderedDict()

        self.logger.debug(f"self.ig_modal_target = {self.ig_modal_target}")
        for modal_name in self.ig_modal_target:
            modal_shape = self.shapes[modal_name]
            self.logger.info(f"{modal_name}: np.zeros({modal_shape})")
            IGs[modal_name] = np.zeros(modal_shape, dtype=np.float32)

        n_steps = float(divide_number)
        for step in range(divide_number):
            step_start = time.time()
            feed_dict = construct_feed(self.batch_idx, self.placeholders, self.all_data,
                                       config=self.config, info=self.info,
                                       scaling=(step + 1) / n_steps,
                                       perturbation_target=self.perturbation_target)
            step_grads = sess.run(grad_tensors, feed_dict=feed_dict)
            for grad_idx, modal_name in enumerate(IGs):
                target_data = self.ig_modal_target_data[modal_name]
                # [0] selects the first element of the gradient (presumably the
                # single batch entry -- TODO confirm).
                IGs[modal_name] += step_grads[grad_idx][0] * target_data / n_steps
            self.logger.info(f'[IG] {step:3d}th / {divide_number} : '
                             f'[TIME] {(time.time() - step_start):7.4f}s')
        self.IGs = IGs

        # When IG is computed correctly, this total approximately equals the
        # difference between the prediction score at scaling factor 1 and at 0.
        self.sum_of_ig = sum(np.sum(ig_values) for ig_values in self.IGs.values())
Esempio n. 3
0
def cal_feature_IG_for_kg(sess, all_data, placeholders, info, config, prediction,
                          model=None, logger=None, verbosity=None, args=None):
    """Compute and dump integrated-gradient visualizations for a knowledge-graph model.

    For each visualization target a ``KnowledgeGraphVisualizer`` is built for
    the selected scalar output. Its integrated gradients are computed with 30
    scaling steps, and the result is dumped into ``config["visualize_path"]``.

    Args:
        sess: session object used to evaluate tensors.
        all_data: dataset; for edge visualization ``label_list[0, target, 0]`` and
            ``label_list[0, target, 1]`` give the two nodes of edge ``target``.
        placeholders: model input placeholders.
        info: dataset info forwarded to ``construct_feed`` and the visualizer.
        config: configuration dict. Reads ``visualize_target`` (an index, or
            ``None`` for all targets), ``visualize_type`` and ``visualize_path``.
        prediction: prediction tensor of the network (node visualization).
        model: model object; ``model.score[target]`` / ``model.loss[target]`` are
            used for the ``edge_score`` / ``edge_loss`` visualization types.
        logger: logger to report progress to; defaults to this module's logger.
        verbosity: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.

    Raises:
        ValueError: if ``visualize_target`` is missing from ``config``, or if
            ``visualize_type`` asks for an edge visualization other than
            ``edge_score`` / ``edge_loss``.
    """
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)

    # Validate the configuration before creating directories or touching the graph.
    if 'visualize_target' not in config:
        raise ValueError('set "visualize_target" in your config.')
    visualize_type = config['visualize_type']
    is_edge = 'edge' in visualize_type
    if is_edge and visualize_type not in ('edge_score', 'edge_loss'):
        raise ValueError(f'unsupported visualize_type {visualize_type!r}: '
                         'edge visualization must be "edge_score" or "edge_loss".')

    divide_number = 30
    outdir = config["visualize_path"]
    os.makedirs(outdir, exist_ok=True)
    batch_idx = [0, ]  # assume batch size is only one.
    feed_dict = construct_feed(batch_idx, placeholders, all_data, config=config, batch_size=1, info=info)
    logger.info(f"config['visualize_target'] = {config['visualize_target']}")

    if config['visualize_target'] is None:
        n_samples = all_data.label_list.shape[1] if is_edge else prediction.shape[1]
        logger.info('visualization targets are all.')
        logger.info(f'n_samples = {n_samples}')
        targets = range(n_samples)
    else:
        targets = [config['visualize_target'], ]

    out_prediction = None
    for target in targets:
        if is_edge:
            edge_outputs = model.score if visualize_type == 'edge_score' else model.loss
            _prediction = edge_outputs[target]
            node1 = all_data.label_list[0, target, 0]
            node2 = all_data.label_list[0, target, 1]
            logger.info(f"edge target = {target} => {node1}-{node2}")
            filename = f'edgepred-{node1}-{node2}'
            vis_nodes = [node1, node2]
        else:
            # for node visualization
            if out_prediction is None:
                # The feed is the same for every target, so run the network only once.
                out_prediction = sess.run(prediction, feed_dict=feed_dict)
            target_index = np.argmax(out_prediction[:, target, :])
            _prediction = prediction[:, target, target_index]
            logger.info(f"target_index = {target_index}")
            filename = f'nodepred-{target}'
            vis_nodes = [target, ]

        visualizer = KnowledgeGraphVisualizer(outdir, info, config, batch_idx, placeholders, all_data, _prediction,
                                              logger=logger, model=model)
        visualizer.cal_integrated_gradients(sess, placeholders, _prediction, divide_number)
        visualizer.dump(filename, vis_nodes)
Esempio n. 4
0
def _assay_string(scores):
    """Return a human-readable label for one task's prediction scores.

    More than two scores are treated as multi-class (``class<argmax>``). Two
    scores are treated as two-class, where index 1 above 0.5 means "active". A
    single score is thresholded at 0.5.
    """
    if len(scores) > 2:  # softmax output
        return f"class{np.argmax(scores)}"
    if len(scores) == 2:  # softmax output
        return "active" if scores[1] > 0.5 else "inactive"
    return "active" if scores > 0.5 else "inactive"


def _select_ig_target(ig_label_target, scores, prediction_tensor, true_label):
    """Choose which output of one task to explain with integrated gradients.

    Returns ``(target_index, target_prediction, target_score)``. Returns ``None``
    when the task must be skipped: ``"correct"`` with a wrong argmax prediction,
    or ``"uncorrect"`` with a right one.
    """
    if ig_label_target == "all":
        return "all", prediction_tensor, np.sum(scores)
    if ig_label_target in ("max", "correct", "uncorrect"):
        target_index = np.argmax(scores)
        if ig_label_target == "correct" and not target_index == true_label:
            return None
        if ig_label_target == "uncorrect" and target_index == true_label:
            return None
    elif ig_label_target == "label":
        target_index = true_label
    else:
        target_index = int(ig_label_target)
    return target_index, prediction_tensor[:, target_index], scores[target_index]


def cal_feature_IG(sess,
                   all_data,
                   placeholders,
                   info,
                   config,
                   prediction,
                   ig_modal_target,
                   ig_label_target,
                   *,
                   model=None,
                   logger=None,
                   args=None):
    """Calculate integrated gradients for every compound and task, and dump them.

    Args:
        sess: session object used to run ``prediction`` and the visualizers.
        all_data: dataset; ``all_data.num``, ``all_data['sequences']`` and
            ``all_data.labels`` are read.
        placeholders: model input placeholders.
        info: dataset info; when it contains ``mol_info``, ``info.mol_info["obj_list"]``
            supplies the molecule objects named in the log and stored in the dump.
        config: configuration dict; ``config["visualize_path"]`` is the output directory.
        prediction: prediction score (output of the network). A rank-3 tensor is
            indexed as ``[:, task, :]``.
        ig_modal_target: forwarded to ``CompoundVisualizer`` (as ``ig_modal_target``
            and ``perturbation_target``) and embedded in the dump filename.
        ig_label_target: which output to explain: ``"max"``, ``"all"``,
            ``"correct"``, ``"uncorrect"``, ``"label"``, or an int-convertible
            class index.
        model: model object; ``model.embedding(sess, data)`` is used when present
            and ``all_data['sequences']`` is not None.
        logger: logger to report progress to; defaults to this module's logger.
        args: optional namespace providing ``visualization_header``,
            ``visualize_resample_num``, ``verbose`` and ``visualize_method``.
            When ``None``, no resampling is done, output is non-verbose and
            ``cal_integrated_gradients`` uses its own default ``method``.
    """
    import logging

    if logger is None:
        logger = logging.getLogger(__name__)
    divide_number = 100
    header = "mol"
    if args is not None and args.visualization_header is not None:
        header = args.visualization_header
    outdir = config["visualize_path"]
    os.makedirs(outdir, exist_ok=True)
    mol_obj_list = info.mol_info["obj_list"] if "mol_info" in info else None
    # Each CompoundVisualizer is built for a specific target_prediction, so its
    # gradient tensors are presumably specific to that output. Reuse them only
    # for the same (task, target label) instead of for every later target.
    grads_cache = {}

    all_count = 0
    correct_count = 0
    visualize_ids = range(all_data.num)
    if args is not None and args.visualize_resample_num:
        visualize_ids = np.random.choice(visualize_ids,
                                         args.visualize_resample_num,
                                         replace=False)
    for compound_id in visualize_ids:
        s = time.time()
        batch_idx = [compound_id]
        if all_data['sequences'] is not None and hasattr(model, "embedding"):
            _data = all_data['sequences']
            _data = np.expand_dims(_data[compound_id, ...], axis=0)
            _data = model.embedding(sess, _data)
            feed_dict = construct_feed(batch_idx,
                                       placeholders,
                                       all_data,
                                       batch_size=1,
                                       info=info,
                                       embedded_layer=_data)
        else:
            feed_dict = construct_feed(batch_idx,
                                       placeholders,
                                       all_data,
                                       batch_size=1,
                                       info=info)

        out_prediction = sess.run(prediction, feed_dict=feed_dict)
        # Normalise to #data x #task x #class for consistency with multitask output.
        multitask = False
        if len(out_prediction.shape) == 1:
            out_prediction = out_prediction[:, np.newaxis, np.newaxis]
        elif len(out_prediction.shape) == 2:
            out_prediction = np.expand_dims(out_prediction, axis=1)
        elif len(out_prediction.shape) == 3:
            if out_prediction.shape[1] > 1:
                multitask = True
        # labels: data x #task/#label
        for idx in range(out_prediction.shape[1]):
            _out_prediction = out_prediction[0, idx, :]
            true_label = np.argmax(
                all_data.labels[compound_id]
            ) if not multitask else all_data.labels[compound_id, idx]
            assay_str = _assay_string(_out_prediction)
            _prediction = prediction[:, idx, :] if len(
                prediction.shape) == 3 else prediction  # multitask = 3

            selected = _select_ig_target(ig_label_target, _out_prediction,
                                         _prediction, true_label)
            if selected is None:
                continue
            target_index, target_prediction, target_score = selected

            try:
                mol_name = Chem.MolToSmiles(mol_obj_list[compound_id])
                mol_obj = mol_obj_list[compound_id]
            except Exception:
                # No molecule info (e.g. mol_obj_list is None) or RDKit failed:
                # best effort, continue without a molecule.
                mol_name = None
                mol_obj = None
            if args is not None and args.verbose:
                print(
                    f"No.{compound_id}, task={idx}: \"{mol_name}\": {assay_str} (score= {_out_prediction}, "
                    f"true_label= {true_label}, target_label= {target_index}, target_score= {target_score})"
                )
            else:
                print(
                    f"No.{compound_id}, task={idx}: \"{mol_name}\": {assay_str}"
                )
            grads_key = (idx, target_index)
            visualizer = CompoundVisualizer(
                sess,
                outdir,
                compound_id,
                info,
                config,
                batch_idx,
                placeholders,
                all_data,
                target_prediction,
                logger=logger,
                model=model,
                ig_modal_target=ig_modal_target,
                perturbation_target=ig_modal_target,
                grads=grads_cache.get(grads_key))
            if grads_cache.get(grads_key) is None:
                grads_cache[grads_key] = visualizer.grads
            ig_options = {} if args is None else {"method": args.visualize_method}
            visualizer.cal_integrated_gradients(sess, divide_number, **ig_options)
            visualizer.check_IG(sess, target_prediction)
            visualizer.dump(
                f"{header}_{compound_id:04d}_task_{idx}_{assay_str}_{ig_modal_target}_scaling.jbl",
                additional_data={
                    "mol": mol_obj,
                    "prediction_score": target_score,
                    "target_label": target_index,
                    "true_label": true_label,
                })
            logger.info(
                f"prediction score: {target_score}\n"
                f"check score: {visualizer.end_score - visualizer.start_score}\n"
                f"sum of IG: {visualizer.sum_of_ig}\n"
                f"time : {time.time() - s}\n")
            all_count += 1
            if np.argmax(_out_prediction) == int(true_label):
                correct_count += 1
    if all_count > 0:
        logger.info(f"accuracy(visualized_data) = {correct_count/all_count}")
    else:
        # Nothing was visualized (empty data or every task filtered out).
        logger.info("accuracy(visualized_data) = n/a (no data visualized)")