def extract_reliable_negatives_fuselier(model, data_stuff, negative_idx, th=0.5, yshape=2, N_inputs=1):
    """Select the negatives whose model score falls below ``th``.

    The samples at ``negative_idx`` are sliced out of ``data_stuff`` via
    ``slice_data`` and scored with ``model.predict`` on the first sliced element.

    Args:
        model: object exposing ``predict``.
        data_stuff: dataset container passed straight through to ``slice_data``.
        negative_idx: numpy index array of candidate negatives (must support
            fancy indexing with an integer array).
        th: score threshold; candidates with score strictly below it are kept.
        yshape: when 2, column 1 of the prediction is used as the score
            (presumably the positive-class probability — TODO confirm).
        N_inputs: forwarded to ``slice_data``.

    Returns:
        Tuple of (selected entries of ``negative_idx``,
        ``make_stats_from_vector(scores)``, number of scores > 0.5).
    """
    sliced = slice_data(data_stuff, negative_idx, N_inputs)
    scores = model.predict(sliced[0])
    if yshape == 2:
        scores = scores[:, 1]

    # Positions (relative to the sliced subset) of low-scoring samples,
    # mapped back to indices in the full dataset.
    below_threshold = scores < th
    reliable = negative_idx[np.where(below_threshold)[0]]

    stats = make_stats_from_vector(scores)
    # NOTE(review): this count uses a fixed 0.5 cut rather than `th` — confirm intent.
    n_above_half = np.where(scores > 0.5, 1, 0).sum()
    return reliable, stats, n_above_half
def masked_balanced_score(y_pred, y_true, weights, balance=True, auc=False):
    """Score predictions only on samples with a strictly positive weight.

    Args:
        y_pred: predictions, indexable by an integer numpy array.
        y_true: ground-truth labels, indexable the same way.
        weights: per-sample weights; entries <= 0 are excluded from scoring.
        balance: use ``balanced_accuracy_score`` instead of ``accuracy_score``.
        auc: if True, use ``roc_auc_score`` regardless of ``balance``.

    Returns:
        The chosen scorer evaluated as ``scorer(y_true_kept, y_pred_kept)``.
    """
    keep = np.where(weights > 0)[0]

    # `auc` takes precedence over `balance`.
    if auc:
        scorer = roc_auc_score
    elif balance:
        scorer = balanced_accuracy_score
    else:
        scorer = accuracy_score

    pred_kept = y_pred[keep]
    true_kept = y_true[keep]
    return scorer(true_kept, pred_kept)
def get_new_weights(y_all, reliable_negatives):
    """Build per-sample class-balancing weights from binary labels.

    Every sample labelled 1 gets ``N_negative / N_tot`` and every other
    sample gets ``N_positive / N_tot``. Here ``N_positive`` is the number of
    label-1 samples, ``N_negative = len(reliable_negatives)``, and
    ``N_tot = N_positive + N_negative``. With these weights the positives and
    the reliable negatives carry equal total weight.

    Args:
        y_all: numpy label array, either 1-D with values in {0, 1} or 2-D
            (one column per class, reduced with ``argmax``) with entries in {0, 1}.
        reliable_negatives: sized collection of reliable-negative indices;
            only its length is used.

    Returns:
        1-D numpy array of weights, one per sample.

    Raises:
        AssertionError: if ``y_all`` is not made of exactly {0, 1} values or
            has more than 2 dimensions.
    """
    binary_elements_condition = set(y_all.reshape(-1)) in [{0, 1}, {0.0, 1.0}]
    two_class_condition = len(y_all.shape) <= 2
    message = 'Y should be binary, y.shape=%s' % str(y_all.shape)
    assert binary_elements_condition and two_class_condition, message
    if len(y_all.shape) == 2:
        # Collapse per-class columns to a 0/1 label vector. This must read the
        # parameter `y_all`; the previous code read an unrelated global `y`.
        y_to_use = y_all.argmax(axis=1)
    else:
        y_to_use = y_all
    N_positive = y_to_use.sum()
    N_negative = len(reliable_negatives)
    N_tot = float(N_positive + N_negative)
    # Each class is weighted by the *other* class's share of the total.
    w_pos = N_negative / N_tot
    w_neg = N_positive / N_tot
    return np.where(y_to_use == 1, w_pos, w_neg)
#======= loading stuff ==========
logger.info('Start')
config = load_yaml(args.config)
if args.data_checkpoint is None:
    args.data_checkpoint = args.output_core + '_data_chk.npz'
config['loader_config']['data_checkpoint'] = args.data_checkpoint
config['training_cfg'] = dict(batch_size=100, epochs=args.epochs, shuffle=True, verbose=0)
x, y, weights = load_from_config(config, args.reload)

# HACK: sanitise features before training. np.nan_to_num maps NaN -> 0 and
# +/-inf -> very large finite values; the np.where then zeroes every entry
# with |x| >= 10, which also removes those former infinities.
x = np.nan_to_num(x)
x = np.where(abs(x) < 10.0, x, 0)
data_stuff = [x, y, weights]

# Reduce y to a 1-D 0/1 label vector, then split indices by label.
# Label-0 samples are treated both as unlabelled and as negative candidates.
if len(y.shape) == 2:
    assert y.shape[1] == 2, "Y should be binary, y.shape=%s" % str(y.shape)
    _y_labels = y.argmax(axis=1)
elif len(y.shape) == 1:
    _y_labels = y
else:
    # The format string needs a %s placeholder; without it this line raised
    # TypeError instead of the intended ValueError.
    raise ValueError('Incorrect shape, y.shape=%s' % str(y.shape))
positive_idx = np.where(_y_labels == 1)[0]
unlabelled_idx = np.where(_y_labels == 0)[0]
negative_idx = np.where(_y_labels == 0)[0]