def get_link_pred_perfs_by_attention(model, edge_y, layer_idx=-1, metric="roc_auc"):
    """Score link prediction using attention values cached by SuperGAT layers.

    The attention tensor stored under ``cache["att_with_negatives"]`` of the
    selected SuperGAT layer is averaged over its last dimension, passed through
    a sigmoid, and compared against ``edge_y`` with the requested metric.

    :param model: GNN model (nn.Module) containing SuperGAT layers
    :param edge_y: [E_pred] tensor of link labels
    :param layer_idx: index into the SuperGAT layers in ``model.modules()``
        order (default ``-1``: the last one)
    :param metric: one of "roc_auc", "average_precision", "accuracy"
    :return: the metric value
    :raises ValueError: if ``metric`` is not a supported name
    """
    cache_list = [
        m.cache for m in model.modules()
        if m.__class__.__name__ == SuperGAT.__name__
    ]
    cache_of_layer_idx = cache_list[layer_idx]

    att = cache_of_layer_idx["att_with_negatives"]  # [E + neg_E, heads]
    att = att.mean(dim=-1)  # [E + neg_E]

    # detach() so .numpy() also works when the cached tensor tracks gradients.
    edge_probs = np_sigmoid(att.detach().cpu().numpy())
    edge_y = edge_y.cpu().numpy()

    if metric == "roc_auc":
        return roc_auc_score(edge_y, edge_probs)
    if metric == "average_precision":
        return average_precision_score(edge_y, edge_probs)
    if metric == "accuracy":
        return accuracy(edge_probs, edge_y)
    # Previously the exception was constructed but never raised, so an unknown
    # metric silently returned None.
    raise ValueError("Inappropriate metric: {}".format(metric))
def eval_pred_one_epoch(sess, opt_dict, smi_list, label_list):
    """Evaluate the predictor over ``smi_list`` in mini-batches without training.

    Inputs are split into consecutive batches of ``FLAGS.batch_size`` (the last
    batch may be shorter), featurized with ``preprocess_inputs`` and run through
    the graph with ``opt_dict.is_training`` fed as False.

    :return: ``(mean_loss, y_truth, y_pred)``. ``mean_loss`` is the average of
        the per-batch losses, ``y_truth`` the concatenated labels produced by
        ``preprocess_inputs``, and ``y_pred`` the sigmoid of the first logit
        column for every sample.
    :raises ZeroDivisionError: if ``smi_list`` is empty (no batches to average).
    """
    batch_size = FLAGS.batch_size
    batch_starts = range(0, len(smi_list), batch_size)
    num_batches = len(batch_starts)

    loss_total = 0.0
    # Collect per-batch arrays and concatenate once at the end (avoids the
    # quadratic cost of re-concatenating on every iteration). The leading empty
    # float array keeps the result dtype/shape identical for any batch count.
    truth_chunks = [np.empty([0, ])]
    pred_chunks = [np.empty([0, ])]

    for start in batch_starts:
        stop = start + batch_size
        adj, x, y = preprocess_inputs(smi_list[start:stop],
                                      label_list[start:stop],
                                      FLAGS.num_max_atoms)
        feed_dict = {opt_dict.x: x, opt_dict.adj: adj,
                     opt_dict.y: y, opt_dict.is_training: False}
        y_pred, loss = sess.run([opt_dict.logits, opt_dict.pred_loss],
                                feed_dict=feed_dict)
        loss_total += loss
        truth_chunks.append(y)
        pred_chunks.append(np_sigmoid(y_pred[:, 0]))

    loss_total /= num_batches
    y_truth_total = np.concatenate(truth_chunks, axis=0)
    y_pred_total = np.concatenate(pred_chunks, axis=0)
    return loss_total, y_truth_total, y_pred_total
def ibm_hat(sess, net, wav, wav_len, args):
    """Threshold the network's a priori SNR estimate into a boolean mask.

    Steps:
    1. Compute ``net.x_MS_3D`` and ``net.x_seq_len`` for ``wav``/``wav_len``.
    2. Run ``net.output`` on those features with ``net.training_ph`` fed as False.
    3. Apply a sigmoid to the output.
    4. Map the result back to dB via ``sigma * sqrt(2) * erfinv(2p - 1) + mu``,
       where ``mu``/``sigma`` are read from ``net.mu``/``net.sigma``.
    5. Convert from dB to linear scale.
    6. Project all rows except the last onto ``args.H_tapered`` (via its
       transpose).
    7. Return True wherever the projection exceeds 1.0.

    NOTE(review): the name suggests an ideal-binary-mask estimate; confirm the
    row/column meaning of the output and of ``args.H_tapered`` with the caller.
    """
    wav_feed = {net.x_ph: wav, net.x_len_ph: wav_len}
    magnitude = sess.run(net.x_MS_3D, feed_dict=wav_feed)
    seq_len = sess.run(net.x_seq_len, feed_dict={net.x_len_ph: wav_len})

    # Statistics used to undo the output normalisation.
    mean, std = sess.run(net.mu), sess.run(net.sigma)

    net_feed = {
        net.x_MS_ph: magnitude,
        net.x_MS_len_ph: seq_len,
        net.training_ph: False,
    }
    prob = utils.np_sigmoid(sess.run(net.output, feed_dict=net_feed))

    # a priori SNR estimate, first in dB and then in linear scale.
    xi_db = std * np.sqrt(2.0) * spsp.erfinv(2.0 * prob - 1) + mean
    xi = 10.0 ** (xi_db / 10.0)

    projected = np.matmul(xi[:-1, :], np.transpose(args.H_tapered))
    return np.greater(projected, 1.0)
def train_pred_one_epoch(sess, opt_dict, smi_list, label_list):
    """Run one training epoch over ``smi_list`` in mini-batches.

    Inputs are split into consecutive batches of ``FLAGS.batch_size`` (the last
    batch may be shorter), featurized with ``preprocess_inputs`` and fed with
    ``opt_dict.is_training`` set to True. ``FLAGS.optimize`` selects the
    update op: ``'fully'`` -> ``opt_dict.train_fully``, ``'predictor'`` ->
    ``opt_dict.train_pred``. Per-batch loss and wall time are printed.

    :return: ``(mean_loss, y_truth, y_pred)``. ``mean_loss`` is the average of
        the per-batch losses, ``y_truth`` the concatenated labels produced by
        ``preprocess_inputs``, and ``y_pred`` the sigmoid of the first logit
        column for every sample.
    :raises ValueError: if ``FLAGS.optimize`` is not ``'fully'`` or ``'predictor'``.
    :raises ZeroDivisionError: if ``smi_list`` is empty (no batches to average).
    """
    # Resolve the update op once. An unknown mode previously left only two
    # fetches in the list and failed later on the 3-way unpack.
    if FLAGS.optimize == 'fully':
        train_op = opt_dict.train_fully
    elif FLAGS.optimize == 'predictor':
        train_op = opt_dict.train_pred
    else:
        raise ValueError(
            "Unknown FLAGS.optimize value: {!r} "
            "(expected 'fully' or 'predictor')".format(FLAGS.optimize))
    fetches = [train_op, opt_dict.logits, opt_dict.pred_loss]

    batch_size = FLAGS.batch_size
    batch_starts = range(0, len(smi_list), batch_size)
    num_batches = len(batch_starts)

    loss_total = 0.0
    # Concatenate once at the end instead of on every iteration.
    truth_chunks = [np.empty([0, ])]
    pred_chunks = [np.empty([0, ])]

    for num, start in enumerate(batch_starts, start=1):
        st_i = time.time()
        stop = start + batch_size
        adj, x, y = preprocess_inputs(smi_list[start:stop],
                                      label_list[start:stop],
                                      FLAGS.num_max_atoms)
        feed_dict = {opt_dict.x: x, opt_dict.adj: adj,
                     opt_dict.y: y, opt_dict.is_training: True}
        _, y_pred, loss = sess.run(fetches, feed_dict)
        y_pred = np_sigmoid(y_pred[:, 0])
        loss_total += loss
        et_i = time.time()
        print("Train_iter : ", num,
              ", loss : ", loss,
              "\t Time:", round(et_i - st_i, 3))
        truth_chunks.append(y)
        pred_chunks.append(y_pred)

    loss_total /= num_batches
    y_truth_total = np.concatenate(truth_chunks, axis=0)
    y_pred_total = np.concatenate(pred_chunks, axis=0)
    return loss_total, y_truth_total, y_pred_total
def infer(sess, net, args):
    """Run the restored model over the test set and save the requested output.

    Restores the checkpoint ``args.model_path + '/epoch-' + str(args.epoch)``.
    For every test signal ``args.test_x[j]`` (length ``args.test_x_len[j]``) it
    computes:

    * ``output_out``: sigmoid of ``net.output``.
    * ``xi_hat_out``: the a priori SNR estimate. It is obtained as
      ``sigma * sqrt(2) * erfinv(2 * output_out - 1) + mu`` in dB (``mu`` and
      ``sigma`` from ``net.mu`` and ``net.sigma``), then converted from dB.
    * ``gain_out``: ``feat.mmse_stsa`` / ``feat.mmse_lsa`` when ``args.gain``
      is ``'mmse-stsa'`` / ``'mmse-lsa'``, otherwise ``net.G``.

    Depending on ``args.out_type``, one of these is written under
    ``args.out_path``:

    * ``'raw'``: ``raw/<fname>.mat`` with key ``'raw'``.
    * ``'xi_hat'``: ``xi_hat/<fname>.mat`` with key ``'xi_hat'``.
    * ``'gain'``: ``gain/<args.gain>/<fname>.mat`` with key ``args.gain``.
    * ``'y'``: ``y/<args.gain>/<fname>.wav``, sampled at ``args.fs``.

    Any other ``out_type`` writes nothing. Progress is printed after each file.
    """
    print("Inference...")

    ## LOAD MODEL
    net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch))  # load model from epoch.

    ## CONVERT STATISTIC CONSTANTS TO NUMPY ARRAY
    mu_np = sess.run(net.mu)  # place mean constant into a numpy array.
    sigma_np = sess.run(net.sigma)  # place standard deviation constant into a numpy array.

    n_test = len(args.test_x_len)
    for j in range(n_test):
        fname = args.test_fnames[j]
        wav_feed = {net.x_ph: [args.test_x[j]], net.x_len_ph: [args.test_x_len[j]]}

        # One session call fetches all three features for this signal instead
        # of three separate runs on the same feed.
        x_MS_out, x_MS_3D_out, x_PS_out = sess.run(
            [net.x_MS, net.x_MS_3D, net.x_PS], feed_dict=wav_feed)
        x_seq_len_out = sess.run(
            net.x_seq_len, feed_dict={net.x_len_ph: [args.test_x_len[j]]})

        output_out = sess.run(net.output, feed_dict={
            net.x_MS_ph: x_MS_3D_out,
            net.x_MS_len_ph: x_seq_len_out,
            net.training_ph: False,
        })  # output of network.
        output_out = utils.np_sigmoid(output_out)

        # a priori SNR estimate (dB, then linear).
        xi_dB_hat_out = np.add(
            np.multiply(
                np.multiply(sigma_np, np.sqrt(2.0)),
                spsp.erfinv(np.subtract(np.multiply(2.0, output_out), 1))),
            mu_np)
        xi_hat_out = np.power(10.0, np.divide(xi_dB_hat_out, 10.0))

        if args.gain == 'mmse-stsa':
            gain_out = feat.mmse_stsa(
                xi_hat_out, feat.ml_gamma_hat(xi_hat_out))  # MMSE-STSA estimator gain.
        elif args.gain == 'mmse-lsa':
            gain_out = feat.mmse_lsa(
                xi_hat_out, feat.ml_gamma_hat(xi_hat_out))  # MMSE-LSA estimator gain.
        else:
            gain_out = sess.run(net.G, feed_dict={net.xi_hat_ph: xi_hat_out})  # gain.

        if args.out_type == 'raw':  # raw outputs from network (.mat).
            out_dir = args.out_path + '/raw'
            os.makedirs(out_dir, exist_ok=True)
            spio.savemat(out_dir + '/' + fname + '.mat', {'raw': output_out})
        elif args.out_type == 'xi_hat':  # a priori SNR estimate output (.mat).
            out_dir = args.out_path + '/xi_hat'
            os.makedirs(out_dir, exist_ok=True)
            spio.savemat(out_dir + '/' + fname + '.mat', {'xi_hat': xi_hat_out})
        elif args.out_type == 'gain':  # gain function output (.mat).
            # Directory and .mat key both follow args.gain (there is no local
            # named `gain`).
            # NOTE(review): 'mmse-stsa'/'mmse-lsa' contain '-', which is not a
            # valid MATLAB variable name; confirm downstream loaders accept it.
            out_dir = args.out_path + '/gain/' + args.gain
            os.makedirs(out_dir, exist_ok=True)
            spio.savemat(out_dir + '/' + fname + '.mat', {args.gain: gain_out})
        elif args.out_type == 'y':  # enhanced speech output (.wav).
            out_dir = args.out_path + '/y/' + args.gain
            os.makedirs(out_dir, exist_ok=True)
            y_out = sess.run(net.y, feed_dict={
                net.G_ph: gain_out,
                net.x_PS_ph: x_PS_out,
                net.x_MS_2D_ph: x_MS_out,
                net.output_ph: output_out,
            })  # enhanced speech output.
            scipy.io.wavfile.write(out_dir + '/' + fname + '.wav', args.fs, y_out)

        print("Inference (%s): %3.2f%%. " % (args.out_type, 100 * ((j + 1) / n_test)), end="\r")
    print('\nInference complete.')
def train(opt_dict):
    """Train the bag-level classifier on MNIST bags and report metrics per epoch.

    Training bags (``FLAGS.n_bags`` of them) and a 1000-bag test set are built
    with ``instances_to_bags`` from the MNIST train and test splits. Each epoch:

    1. Runs ``opt_dict.train_op`` once per training bag, printing loss and
       step time.
    2. Evaluates every test bag without running the train op.
    3. Calls ``print_metrics`` on the labels and sigmoid outputs of both passes.
    """
    train_set, _valid_set, test_set = get_mnist_data()
    x_train, y_train = instances_to_bags(ds=train_set, n_inst=FLAGS.n_inst,
                                         target=FLAGS.target,
                                         n_bags=FLAGS.n_bags,
                                         p=FLAGS.prob_target)
    x_test, y_test = instances_to_bags(ds=test_set, n_inst=FLAGS.n_inst,
                                       target=FLAGS.target, n_bags=1000,
                                       p=FLAGS.prob_target)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(FLAGS.n_epoches):
            # Train. Per-bag results are collected in lists and concatenated
            # once; the leading empty float array keeps the result dtype.
            true_train = [np.empty([0, ])]
            pred_train = [np.empty([0, ])]
            for i in range(len(x_train)):
                st_i = time.time()
                xi = x_train[i]
                yi = np.asarray([y_train[i]])
                feed_dict = {opt_dict.x: xi, opt_dict.y: yi}
                ops = [opt_dict.train_op, opt_dict.loss, opt_dict.logits]
                _, loss, logits = sess.run(ops, feed_dict=feed_dict)
                et_i = time.time()
                print("Training", epoch, "-th epoch\t",
                      i, "-th bag\t Loss=", loss,
                      "\t Time:", round(et_i - st_i, 3), "(s)")
                true_train.append(yi)
                pred_train.append(np_sigmoid(logits))
            print_metrics(np.concatenate(true_train, axis=0),
                          np.concatenate(pred_train, axis=0))

            # Test: iterate over the test bags themselves. The loop previously
            # used the training-set size, which overruns or truncates x_test.
            true_test = [np.empty([0, ])]
            pred_test = [np.empty([0, ])]
            for i in range(len(x_test)):
                xi = x_test[i]
                yi = np.asarray([y_test[i]])
                feed_dict = {opt_dict.x: xi, opt_dict.y: yi}
                ops = [opt_dict.loss, opt_dict.logits]
                loss, logits = sess.run(ops, feed_dict=feed_dict)
                true_test.append(yi)
                pred_test.append(np_sigmoid(logits))
            print_metrics(np.concatenate(true_test, axis=0),
                          np.concatenate(pred_test, axis=0))

    print("Finish training and test")
    return