def test():
    """Restore a trained MNIST model and log relevance heatmaps for one batch.

    Builds the ``layers()`` network under the ``'model'`` variable scope,
    restores its variables through ``init_vars``, and feeds one MNIST test
    batch rescaled from [0, 1] to [-1, 1]. The merged summaries are written
    to ``FLAGS.summaries_dir + '/my_model'``. When ``FLAGS.relevance`` is
    set, the 'simple' LRP relevances of that batch are plotted over the
    input images with ``plot_relevances``.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784], name='input')
        with tf.variable_scope('model'):
            my_network = layers()
            output = my_network.forward(x)
            # Build the relevance graph only when it is requested. The old
            # code fetched it unconditionally, so FLAGS.relevance=False
            # ended in a NameError.
            relevance = None
            if FLAGS.relevance:
                relevance = my_network.lrp(output, 'simple', 1.0)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/my_model')

        # Initialize variables and reload the model. The value init_vars
        # returns is not needed here.
        init_vars(sess)

        # Extract one test batch; the model is fed pixels mapped to [-1, 1].
        xs, _ = mnist.test.next_batch(FLAGS.batch_size)
        scaled = (2 * xs) - 1

        if relevance is None:
            summary = sess.run(merged, feed_dict={x: scaled})
            test_writer.add_summary(summary, 0)
        else:
            summary, relevance_test = sess.run([merged, relevance],
                                               feed_dict={x: scaled})
            test_writer.add_summary(summary, 0)

            # Undo the [-1, 1] input scaling for display. Previously the
            # (v + 1) / 2 mapping was applied to the unscaled pixels, which
            # compressed every image into the upper half of the range.
            images = (scaled.reshape([FLAGS.batch_size, 28, 28, 1]) + 1) / 2.0
            relevances = relevance_test.reshape([FLAGS.batch_size, 28, 28, 1])
            plot_relevances(relevances, images, test_writer)

        test_writer.close()
def _load_vgg_test_images(image_dir='./test_data/tmp_image/',
                          gt_path='./test_data/val.txt'):
    """Load the 224x224x3 ``.jpg`` images in `image_dir` with their labels.

    Each non-empty line of `gt_path` is ``<name><5-char ext> <label>``.
    Image files are matched to labels by their base name with the 4-char
    ``.jpg`` extension removed. Images whose loaded shape is not
    (224, 224, 3) are skipped.

    Returns:
        (images, labels): images stacked on axis 0, and a list of int labels
        aligned with them.

    Raises:
        KeyError: an image has no entry in `gt_path`.
    """
    import os

    # Map name -> label once. setdefault keeps the first occurrence, matching
    # the old "first np.where match" lookup.
    label_by_name = {}
    with open(gt_path, 'r') as gt_file:
        for line in gt_file:
            fields = line.split()
            if not fields:
                continue
            label_by_name.setdefault(fields[0][:-5], int(fields[1]))

    images = []
    labels = []
    for path in glob(image_dir + "*.jpg"):
        image = np.expand_dims(utils_vgg.load_image(path), 0)
        if image.shape[1:] != (224, 224, 3):
            continue
        # os.path.basename works with the platform's separator (the old
        # split('\\') only worked on Windows).
        name = os.path.basename(path)
        labels.append(label_by_name[name[:-4]])
        images.append(image)

    if not images:
        return np.empty((0, 224, 224, 3)), labels
    return np.concatenate(images, 0), labels


def train():
    """Evaluate the VGG network on local test images and visualize relevances.

    Uses the images from ``_load_vgg_test_images``. For every full batch it:
      * logs the accuracy summary;
      * prints the top softmax entry and the top classes for each sample;
      * prints the per-layer relevance sums.
    When ``FLAGS.relevance`` is set, it also plots the batch's relevance maps
    (RAP positive + negative, or the method given by
    ``FLAGS.relevance_method``) over the images.
    """
    x_test, y_test = _load_vgg_test_images()
    # One-hot labels computed once in NumPy. This replaces re-evaluating
    # tf.one_hot(y_test, 1000) for every batch.
    y_test_one_hot = np.eye(1000, dtype=np.float32)[np.asarray(y_test, dtype=np.int64)]

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
            y_ = tf.placeholder(tf.float32, shape=[None, 1000])
            phase = tf.placeholder(tf.bool, name='phase')

        with tf.variable_scope('model'):
            net = nn()
            inp = tf.reshape(x, [FLAGS.batch_size, 224, 224, 3])
            rgb_scaled = inp * 255.0
            VGG_MEAN = [103.939, 116.779, 123.68]
            # Convert RGB to mean-subtracted BGR
            red, green, blue = tf.split(num_or_size_splits=3, value=rgb_scaled, axis=3)
            assert red.get_shape().as_list()[1:] == [224, 224, 1]
            assert green.get_shape().as_list()[1:] == [224, 224, 1]
            assert blue.get_shape().as_list()[1:] == [224, 224, 1]
            bgr = tf.concat(values=[
                blue - VGG_MEAN[0],
                green - VGG_MEAN[1],
                red - VGG_MEAN[2],
            ], axis=3)
            assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
            op = net.forward(bgr)
            y = tf.reshape(op, [FLAGS.batch_size, 1000])
            soft = tf.nn.softmax(y)

        with tf.variable_scope('relevance'):
            # Empty lists are valid sess.run fetches and come back as [].
            LRP = []
            RAP_pos, RAP_neg = [], []
            relevance_layerwise = []
            relevance_layerwise_pos = []
            relevance_layerwise_neg = []
            if FLAGS.relevance:
                # Start from the logits masked by the ground-truth one-hot
                # labels. To explain the predicted class instead, mask with
                # tf.one_hot(tf.argmax(soft, -1), 1000).
                mm = y_
                if FLAGS.relevance_method == 'RAP':
                    RAP_pos, RAP_neg = net.RAP(y * mm, y * mm)
                    R_p = y * mm
                    R_n = y * mm
                    for layer in net.modules[::-1]:
                        R_p, R_n = net.RAP_layerwise(layer, R_p, R_n)
                        relevance_layerwise_pos.append(R_p)
                        relevance_layerwise_neg.append(R_n)
                else:
                    # NOTE(review): this passes the method name as net.RAP's
                    # second argument; the other scripts call
                    # net.lrp(output, method, eps) here -- confirm intent.
                    LRP = net.RAP(y * mm, FLAGS.relevance_method)
                    R = y * mm
                    for layer in net.modules[::-1]:
                        R = net.lrp_layerwise(layer, R, FLAGS.relevance_method)
                        relevance_layerwise.append(R)

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')

        # Count batches over the images actually loaded, not the files found.
        for i in range(len(y_test) // FLAGS.batch_size):
            d = next_batch(FLAGS.batch_size, x_test, y_test_one_hot)
            test_inp = {x: d[0], y_: d[1], phase: False}
            (summary, acc, relevance_test, RAP_p, RAP_n, logits, probs,
             rel_layer, rel_layer_rap_p, rel_layer_rap_n) = sess.run(
                 [merged, accuracy, LRP, RAP_pos, RAP_neg, y, soft,
                  relevance_layerwise, relevance_layerwise_pos,
                  relevance_layerwise_neg],
                 feed_dict=test_inp)
            test_writer.add_summary(summary, i)
            print('Accuracy at step %s: %f' % (i, acc))

            for m in range(FLAGS.batch_size):
                ax = np.argmax(probs[m, :])
                print(logits[m, ax], probs[m, ax], ax, d[1][m, ax])
                utils_vgg.print_prob(probs[m], './synset.txt')

            if FLAGS.relevance_method == 'RAP':
                vis = RAP_p + RAP_n
                print([np.sum(rel) for rel in rel_layer_rap_p])
                print([np.sum(rel) for rel in rel_layer_rap_n])
            else:
                vis = relevance_test
                print([np.sum(rel) for rel in rel_layer])

            if FLAGS.relevance:
                # plot test images with relevances overlaid
                images = d[0].reshape([FLAGS.batch_size, 224, 224, 3])
                plot_relevances(vis.reshape([FLAGS.batch_size, 224, 224, 3]),
                                images, test_writer)

        train_writer.close()
        test_writer.close()
def train():
    """Train the MNIST network and plot relevances of the last test batch.

    Every ``FLAGS.test_every`` steps the loop evaluates a test batch and logs
    it to ``<summaries_dir>/test``; it also computes relevances when
    ``FLAGS.relevance`` is set. Every other step runs one training update
    and logs to ``<summaries_dir>/train``. When ``FLAGS.relevance`` is set,
    the relevances of the last evaluated test batch are plotted at the end.
    """
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    with tf.Session() as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784], name='x-input')
            y_ = tf.placeholder(tf.float32, [FLAGS.batch_size, 10], name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        # Model definition
        with tf.variable_scope('model'):
            net = nn()
            y = net.forward(x)

        with tf.variable_scope('relevance'):
            # An empty list is a valid sess.run fetch that returns [].
            LRP = net.lrp(y, FLAGS.relevance_method, 1e-8) if FLAGS.relevance else []

        # Accuracy computation
        with tf.name_scope('correct_prediction'):
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        trainer = net.fit(output=y, ground_truth=y_, loss='softmax_crossentropy',
                          optimizer='adam', opt_params=[FLAGS.learning_rate])
        # net.fit runs after the global initializer and may create new
        # variables (presumably optimizer state), so initialize the
        # non-trainable ones now.
        uninit_vars = set(tf.global_variables()) - set(tf.trainable_variables())
        tf.variables_initializer(uninit_vars).run()

        test_inp = None
        relevance_test = None
        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:
                d = feed_dict(mnist, False)
                test_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, acc, relevance_test = sess.run(
                    [merged, accuracy, LRP], feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                print('Accuracy at step %s: %f' % (i, acc))
            else:
                d = feed_dict(mnist, True)
                train_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                # Only fetch what is used; the relevance pass is not needed
                # on training steps.
                summary, _ = sess.run([merged, trainer.train], feed_dict=train_inp)
                train_writer.add_summary(summary, i)

        # Plot the last test batch with its relevances overlaid.
        if FLAGS.relevance and test_inp is not None:
            # Index by the input placeholder itself: dict.keys()[0] is not
            # subscriptable on Python 3 and is order-dependent on Python 2.
            images = test_inp[x].reshape([FLAGS.batch_size, 28, 28, 1])
            # assumes feed_dict() yields inputs in [-1, 1] -- TODO confirm
            images = (images + 1) / 2.0
            plot_relevances(
                relevance_test.reshape([FLAGS.batch_size, 28, 28, 1]),
                images, test_writer)

        train_writer.close()
        test_writer.close()
def train():
    """Train the convolutional MNIST network on CSV data and plot relevances.

    Inputs are reshaped to ``image_dim x image_dim`` and padded by 2 pixels
    on each side before entering the network. Every ``FLAGS.test_every``
    steps a test batch is evaluated and its predicted labels are recorded;
    the model is also saved when ``FLAGS.save_model`` is set. Every other
    step runs one training update. When ``FLAGS.relevance`` is set, the
    relevances of the last test batch are cropped back to the unpadded size
    and plotted with the predicted labels.
    """
    # Import data (other CSV variants used "<image_dim>_train_y.csv" /
    # "<image_dim>_test_y.csv" with TFLData).
    train_file_path = os.path.join("mnist_csvs", "mnist_train.csv")
    test_file_path = os.path.join("mnist_csvs", "mnist_test.csv")
    mnist = MnistData((train_file_path, test_file_path, (1000, 1000)))

    pad = 2  # pixels of zero padding added on each spatial side
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, FLAGS.image_dim * FLAGS.image_dim],
                               name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
            keep_prob = tf.placeholder(tf.float32)

        with tf.variable_scope('model'):
            net = nn()
            inp = tf.pad(
                tf.reshape(x, [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1]),
                [[0, 0], [pad, pad], [pad, pad], [0, 0]])
            op = net.forward(inp)
            y = tf.squeeze(op)
            trainer = net.fit(output=y, ground_truth=y_, loss='softmax_crossentropy',
                              optimizer='adam', opt_params=[FLAGS.learning_rate])

        with tf.variable_scope('relevance'):
            # An empty list is a valid sess.run fetch that returns [].
            LRP = net.lrp(op, FLAGS.relevance_method, 1e-8) if FLAGS.relevance else []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # Created once. Building tf.argmax inside the loop added a new op to
        # the graph at every test step.
        predicted_labels = tf.argmax(y, 1)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
        tf.global_variables_initializer().run()

        utils = Utils(sess, FLAGS.checkpoint_reload_dir)
        if FLAGS.reload_model:
            utils.reload_model()

        test_inp = None
        relevance_test = None
        y_labels = None
        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:
                # test-set accuracy
                d = feed_dict(mnist, False)
                test_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, acc, relevance_test, y_labels = sess.run(
                    [merged, accuracy, LRP, predicted_labels], feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                print('Accuracy at step %s: %f' % (i, acc))
                # save model if required
                if FLAGS.save_model:
                    utils.save_model()
            else:
                d = feed_dict(mnist, True)
                train_inp = {x: d[0], y_: d[1], keep_prob: d[2]}
                summary, _ = sess.run([merged, trainer.train], feed_dict=train_inp)
                train_writer.add_summary(summary, i)

        # Plot the last test batch with its relevances overlaid.
        if FLAGS.relevance and test_inp is not None:
            # Crop the padding back off so relevances align with the inputs.
            relevance_test = relevance_test[:, pad:FLAGS.image_dim + pad,
                                            pad:FLAGS.image_dim + pad, :]
            # Index by the placeholder: dict.keys()[0] fails on Python 3.
            images = test_inp[x].reshape(
                [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1])
            plot_relevances(
                relevance_test.reshape(
                    [FLAGS.batch_size, FLAGS.image_dim, FLAGS.image_dim, 1]),
                images, test_writer, y_labels)

        train_writer.close()
        test_writer.close()
def _subset_file_count(subset_idx):
    """Return the hard-coded number of .h5 shards stored for `subset_idx`."""
    if subset_idx in (0, 1, 2):
        return 8
    if subset_idx == 6:
        return 14
    if subset_idx == 8:
        return 15
    return 16


def _read_h5_shard(subset, num):
    """Load the 'X' and 'Y' datasets of one shard and close the file."""
    path = './src/data/3D_data/' + subset + '_' + str(num) + '.h5'
    with h5py.File(path, 'r') as h5f:
        return np.asarray(h5f['X']), np.asarray(h5f['Y'])


def _shuffle_together(x, y):
    """Reorder x and y with one shared random permutation."""
    idx = np.arange(0, len(y))
    np.random.shuffle(idx)
    return x[idx], y[idx]


def _rotation_augment(x_pos):
    """Stack x_pos followed by copies rotated by 90..315 degrees (8x total)."""
    angles = (90, 180, 270, 45, 135, 225, 315)
    rotated = [np.zeros_like(x_pos) for _ in angles]
    for out, angle in zip(rotated, angles):
        for k in range(len(x_pos)):
            out[k] = rotate(x_pos[k], angle, reshape=False)
    return np.concatenate([x_pos] + rotated)


def _labelled_shuffle(pos_chunks, neg_chunks, name):
    """Concatenate positives (label 1) and negatives (label 0), then shuffle."""
    if not pos_chunks:
        raise ValueError('no positive samples found in ' + name)
    x_pos = np.concatenate(pos_chunks)
    x_neg = np.concatenate(neg_chunks)
    x_all = np.concatenate([x_pos, x_neg])
    y_all = np.concatenate([np.ones(len(x_pos)), np.zeros(len(x_neg))])
    return _shuffle_together(x_all, y_all)


def _load_heldout_subset(tag):
    """Test set: every positive of subset `tag` plus 3x random negatives per shard."""
    sub = 'subset' + str(tag)
    pos_chunks, neg_chunks = [], []
    for num in range(_subset_file_count(tag)):
        x_tmp, y_tmp = _read_h5_shard(sub, num)
        if np.max(y_tmp) == 0:
            continue
        x_pos = x_tmp[np.where(y_tmp == 1)[0]]
        pos_chunks.append(x_pos)
        neg_index = np.random.choice(np.where(y_tmp == 0)[0], len(x_pos) * 3,
                                     replace=False)
        neg_chunks.append(x_tmp[neg_index])
    return _labelled_shuffle(pos_chunks, neg_chunks, sub)


def _load_augmented_subset(subset_idx):
    """Training data of one subset: rotation-augmented positives plus 5x negatives."""
    subset = 'subset' + str(subset_idx)
    pos_chunks, neg_chunks = [], []
    for num in range(_subset_file_count(subset_idx)):
        x_tmp, y_tmp = _read_h5_shard(subset, num)
        if np.max(y_tmp) == 0:
            continue
        x_pos = x_tmp[np.where(y_tmp == 1)[0]]
        augmented = _rotation_augment(x_pos)
        idx2 = np.arange(0, len(augmented))
        np.random.shuffle(idx2)
        augmented = augmented[idx2]
        # NOTE(review): the first positive shard of a subset keeps 1/4 of its
        # augmented positives and later shards keep 1/5. This is preserved
        # from the original but looks inconsistent -- confirm the ratio.
        divisor = 4 if not pos_chunks else 5
        pos_chunks.append(augmented[0:int(len(augmented) / divisor)])
        neg_index = np.random.choice(np.where(y_tmp == 0)[0], len(x_pos) * 5,
                                     replace=False)
        neg_chunks.append(x_tmp[neg_index])
    return _labelled_shuffle(pos_chunks, neg_chunks, subset)


def _binary_one_hot(labels, batch_size):
    """Turn a vector of 0/1 labels into rows ``[1 - label, label]``."""
    one_hot = np.zeros([batch_size, 2])
    one_hot[:, 0] = np.ones([batch_size]) - labels
    one_hot[:, 1] = labels
    return one_hot


def train(tag):
    """Train the 3D CNN with subset `tag` held out as the test set.

    The test set comes from ``subset<tag>`` shards. The training set comes
    from the other nine subsets, with rotation-augmented positives. The
    model is trained on 32x32x32x1 volumes with two-class one-hot labels.
    Test batches are evaluated every ``FLAGS.test_every`` steps, and the
    model is saved then when ``FLAGS.save_model`` is set. Logs go to
    ``./conv_log/<tag>_{train,test}``. When ``FLAGS.relevance`` is set, the
    relevances of the last test batch are plotted.
    """
    x_test, y_test = _load_heldout_subset(tag)
    print(len(x_test))
    print(len(y_test))

    # Collect each subset once and concatenate at the end. Repeated
    # np.concatenate was quadratic in the data size.
    x_chunks, y_chunks = [], []
    seen = 0
    for i in range(10):
        if i == tag:
            continue
        x_sub, y_sub = _load_augmented_subset(i)
        x_chunks.append(x_sub)
        y_chunks.append(y_sub)
        seen += len(x_sub)
        print(seen)
    x_train = np.concatenate(x_chunks)
    y_train = np.concatenate(y_chunks)
    del x_chunks, y_chunks
    print(len(x_train))
    print(len(y_train))

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, 32, 32, 32, 1], name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 2], name='y-input')
            phase = tf.placeholder(tf.bool, name='phase')

        with tf.variable_scope('model'):
            net = nn(phase)
            inp = tf.reshape(x, [FLAGS.batch_size, 32, 32, 32, 1])
            op = net.forward(inp)
            y = tf.reshape(op, [FLAGS.batch_size, 2])
            soft = tf.nn.softmax(y)
            trainer = net.fit(output=y, ground_truth=y_, loss='focal loss',
                              optimizer='adam', opt_params=[FLAGS.learning_rate])

        with tf.variable_scope('relevance'):
            # An empty list is a valid sess.run fetch that returns [].
            LRP = net.lrp(y, FLAGS.relevance_method, 1) if FLAGS.relevance else []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter('./conv_log/' + str(tag) + '_train', sess.graph)
        test_writer = tf.summary.FileWriter('./conv_log/' + str(tag) + '_test')
        tf.global_variables_initializer().run()

        utils = Utils(sess, './3D_model/subset' + str(tag))
        if FLAGS.reload_model:
            utils.reload_model()

        train_acc = []
        test_acc = []
        test_inp = None
        relevance_test = None
        for i in range(FLAGS.max_steps):
            if i % FLAGS.test_every == 0:
                # test-set accuracy
                x_test_batch, y_test_batch = next_batch(FLAGS.batch_size, x_test, y_test)
                y_test_batch = _binary_one_hot(y_test_batch, FLAGS.batch_size)
                test_inp = {x: x_test_batch, y_: y_test_batch, phase: False}
                summary, acc, relevance_test, logits, probs = sess.run(
                    [merged, accuracy, LRP, y, soft], feed_dict=test_inp)
                test_writer.add_summary(summary, i)
                test_acc.append(acc)
                print('-----------')
                for m in range(FLAGS.batch_size):
                    print(np.argmax(y_test_batch[m, :]), y_test_batch[m, :], end=" ")
                    print(np.argmax(logits[m, :]), logits[m, :], end=" ")
                    print(probs[m, :])
                    print("|")
                print('Accuracy at step %s: %f' % (i, acc))
                print(tag)
                # save model if required
                if FLAGS.save_model:
                    utils.save_model()
            else:
                x_train_batch, y_train_batch = next_batch(FLAGS.batch_size, x_train, y_train)
                y_train_batch = _binary_one_hot(y_train_batch, FLAGS.batch_size)
                train_inp = {x: x_train_batch, y_: y_train_batch, phase: True}
                # The relevance pass and outputs were fetched but unused here.
                summary, acc2, _ = sess.run([merged, accuracy, trainer.train],
                                            feed_dict=train_inp)
                train_writer.add_summary(summary, i)
                train_acc.append(acc2)
        print(np.mean(train_acc), np.mean(test_acc))

        # Plot the last test batch with its relevances overlaid.
        if FLAGS.relevance and test_inp is not None:
            # Index by the placeholder: dict.keys()[0] fails on Python 3.
            images = test_inp[x].reshape([FLAGS.batch_size, 32, 32, 32, 1])
            plot_relevances(
                relevance_test.reshape([FLAGS.batch_size, 32, 32, 32, 1]),
                images, test_writer)

        train_writer.close()
        test_writer.close()
def train(tag):
    """Explain a saved demo batch with the 3D model trained for subset `tag`.

    Loads ``./Demo_img/test_demo_x.npy`` and ``./Demo_img/test_demo_y.npy``.
    The model is restored from ``./3D_model/subset<tag>`` when
    ``FLAGS.reload_model`` is set. The softmax outputs are printed and saved
    to ``./Demo_img/soft.npy``. When ``FLAGS.relevance`` is set, the
    relevances of the softmax output are plotted over the input volumes.
    """
    tag = int(str(tag))
    x_test_batch = np.load('./Demo_img/test_demo_x.npy')
    y_test_batch = np.load('./Demo_img/test_demo_y.npy')

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Input placeholders
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, 32, 32, 32, 1], name='x-input')
            y_ = tf.placeholder(tf.float32, [None, 2], name='y-input')
            phase = tf.placeholder(tf.bool, name='phase')

        with tf.variable_scope('model'):
            net = nn(phase)
            inp = tf.reshape(x, [FLAGS.batch_size, 32, 32, 32, 1])
            op = net.forward(inp)
            y = tf.reshape(op, [FLAGS.batch_size, 2])
            soft = tf.nn.softmax(y)

        with tf.variable_scope('relevance'):
            if FLAGS.relevance:
                LRP = net.lrp(soft, FLAGS.relevance_method, 2)
                # Layer-wise relevances, propagated from the softmax output.
                relevance_layerwise = []
                R = soft
                for layer in net.modules[::-1]:
                    R = net.lrp_layerwise(layer, R, FLAGS.relevance_method, 2)
                    relevance_layerwise.append(R)
            else:
                LRP = []
                relevance_layerwise = []

        with tf.name_scope('accuracy'):
            accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)), tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter('./conv_log/LRP')
        tf.global_variables_initializer().run()

        utils = Utils(sess, './3D_model/subset' + str(tag))
        if FLAGS.reload_model:
            utils.reload_model()

        test_inp = {x: x_test_batch, y_: y_test_batch, phase: False}
        # The logits and layer-wise relevances are fetched for inspection only.
        relevance_test, logits, soft_val, rel_layer = sess.run(
            [LRP, y, soft, relevance_layerwise], feed_dict=test_inp)
        for m in range(FLAGS.batch_size):
            print(soft_val[m, :])
        np.save('./Demo_img/soft.npy', soft_val)

        if FLAGS.relevance:
            # Index by the placeholder: dict.keys()[0] is not subscriptable on
            # Python 3 and is order-dependent on Python 2.
            images = test_inp[x].reshape([FLAGS.batch_size, 32, 32, 32, 1])
            plot_relevances(
                relevance_test.reshape([FLAGS.batch_size, 32, 32, 32, 1]),
                images, test_writer)

        test_writer.close()