def run_preparing(g, sess, args):
    """Run the preparation pipeline up to the tensor provider build.

    Loads the input data and the explanation graph named in ``args``,
    resolves the run flags, loads the loss definitions from ``loss/`` and
    calls ``SwitchTensorProvider.build`` with ``load_embeddings=False``.

    Args:
        g: Not read by this function; the signature mirrors ``run_test``
            and ``run_training``.
        sess: Not read by this function.
        args: Parsed arguments; ``args.data``, ``args.expl_graph`` and
            ``args.flags`` are read here, and ``args`` is passed to ``Flags``.

    Returns:
        None.
    """
    dataset = expl_graph.load_input_data(args.data)
    expl, opts = expl_graph.load_explanation_graph(args.expl_graph, args.flags)

    flags = Flags(args, opts)
    flags.update()

    # Loss lookup: the returned class is not used in this function, but the
    # loader calls themselves are preserved.
    loader = expl_graph.LossLoader()
    loader.load_all("loss/")
    loader.get_loss(flags.sgd_loss)

    provider = expl_graph.SwitchTensorProvider()

    # An embedding generator is only created when flags.embedding is set.
    generator = None
    if flags.embedding:
        generator = expl_graph.EmbeddingGenerator()
        generator.load(flags.embedding)

    build_kwargs = {
        "load_embeddings": False,
        "embedding_generator": generator,
    }
    provider.build(expl, opts, dataset, flags, **build_kwargs)
def run_test(g, sess, args):
    """Restore a saved model, evaluate it and write the results to disk.

    The function builds the computational explanation graph, restores the
    variables from ``flags.model`` and then evaluates loss/output either in
    minibatches (when ``args.data`` is given) or in a single pass per goal.
    Outputs are written with ``np.save`` to ``flags.output`` and the
    per-goal ``inside`` values are pickled to ``output.pkl`` (relative to
    the current working directory).

    Args:
        g: Object passed to ``save_draw_graph`` when ``flags.draw_graph``
            is set.
        sess: Session used for ``saver.restore`` and ``sess.run``.
        args: Parsed arguments; ``args.data``, ``args.expl_graph`` and
            ``args.flags`` are read here, and ``args`` is passed to ``Flags``.

    Returns:
        None.
    """
    if args.data is not None:
        input_data = expl_graph.load_input_data(args.data)
    else:
        input_data = None
    graph, options = expl_graph.load_explanation_graph(args.expl_graph, args.flags)
    flags = Flags(args, options)
    flags.update()

    loss_loader = expl_graph.LossLoader()
    loss_loader.load_all("loss/")
    loss_cls = loss_loader.get_loss(flags.sgd_loss)

    tensor_provider = expl_graph.SwitchTensorProvider()
    embedding_generator = None
    if flags.embedding:
        embedding_generator = expl_graph.EmbeddingGenerator()
        embedding_generator.load(flags.embedding, key="test")
    cycle_embedding_generator = None
    if flags.cycle:
        cycle_embedding_generator = expl_graph.CycleEmbeddingGenerator()
        cycle_embedding_generator.load(options)
    tensor_provider.build(
        graph,
        options,
        input_data,
        flags,
        load_embeddings=True,
        embedding_generator=embedding_generator,
    )
    comp_expl_graph = expl_graph.ComputationalExplGraph()
    goal_inside = comp_expl_graph.build_explanation_graph(
        graph, tensor_provider, cycle_embedding_generator
    )
    if input_data is not None:
        goal_dataset = build_goal_dataset(input_data, tensor_provider)
    else:
        goal_dataset = None
    if flags.draw_graph:
        save_draw_graph(g, "test")
    loss, output = loss_cls().call(graph, goal_inside, tensor_provider)
    saver = tf.train.Saver()
    saver.restore(sess, flags.model)

    if flags.cycle:
        # NOTE(review): this branch assigns neither `feed_dict` nor
        # `total_output`, both of which are read after the branches below,
        # so the cycle path cannot complete as written. Left unchanged
        # because the intended behaviour is not visible here -- confirm.
        optimize_solve(
            sess,
            goal_dataset,
            goal_inside,
            flags,
            [embedding_generator, cycle_embedding_generator],
        )
    elif goal_dataset is not None:
        # A dataset is given: evaluate every goal in minibatches.
        batch_size = flags.sgd_minibatch_size
        total_loss = [[] for _ in range(len(goal_dataset))]
        total_output = [[] for _ in range(len(goal_dataset))]
        for j, goal in enumerate(goal_dataset):
            ph_vars = goal["placeholders"]
            dataset = goal["dataset"]
            num = dataset.shape[1]
            num_itr = (num + batch_size - 1) // batch_size  # ceil division
            if not flags.no_verb:
                progbar = tf.keras.utils.Progbar(num_itr)
            idx = list(range(num))
            for itr in range(num_itr):
                temp_idx = idx[itr * batch_size : (itr + 1) * batch_size]
                if len(temp_idx) < batch_size:
                    # Short final batch: pad the index array with index 0 up
                    # to batch_size; the padded rows are dropped again by the
                    # `[: len(temp_idx)]` slices below.
                    padding_idx = np.zeros((batch_size,), dtype=np.int32)
                    padding_idx[: len(temp_idx)] = temp_idx
                    feed_dict = {
                        ph: dataset[i, padding_idx] for i, ph in enumerate(ph_vars)
                    }
                else:
                    feed_dict = {
                        ph: dataset[i, temp_idx] for i, ph in enumerate(ph_vars)
                    }
                if embedding_generator:
                    feed_dict = embedding_generator.build_feed(feed_dict)
                batch_loss, batch_output = sess.run(
                    [loss[j], output[j]], feed_dict=feed_dict
                )
                if not flags.no_verb:
                    # Report completed steps (1-based) so the bar reaches
                    # its target on the last batch.
                    progbar.update(itr + 1)
                total_loss[j].extend(batch_loss[: len(temp_idx)])
                total_output[j].extend(batch_output[: len(temp_idx)])
            print("loss:", np.mean(total_loss[j]))
            print("output:", np.array(total_output[j]).shape)
    else:
        # No dataset: one evaluation per goal with a dataset-free feed.
        feed_dict = {}
        total_loss = []
        total_output = []
        if embedding_generator:
            feed_dict = embedding_generator.build_feed(feed_dict)
        for j in range(len(loss)):
            j_loss, j_output = sess.run([loss[j], output[j]], feed_dict=feed_dict)
            total_loss.append(j_loss)
            total_output.append(j_output)
        print("loss:", np.mean(total_loss))
        print("output:", np.array(total_output).shape)

    # NOTE(review): statement nesting was reconstructed from a collapsed
    # source; this loop is assumed to sit at function level. On the
    # minibatch path `feed_dict` is therefore the feed of the last batch
    # that was run -- confirm that this is intended.
    total_goal_inside = []
    for goal_tensors in goal_inside:
        inside_result = sess.run([goal_tensors['inside']], feed_dict=feed_dict)
        total_goal_inside.append(inside_result[0])

    print("[SAVE]", flags.output)
    np.save(flags.output, total_output)

    data = {}
    for goal_info, inside_value in zip(graph.goals, total_goal_inside):
        name = to_string_goal(goal_info.node.goal)
        data[goal_info.node.id] = {"name": name, "data": inside_value}
    # Context manager guarantees the pickle file is flushed and closed.
    with open('output.pkl', 'wb') as fp:
        pickle.dump(data, fp)
def run_training(g, sess, args):
    """Build the computational explanation graph and run an optimizer on it.

    After loading data, graph, flags and loss, one of three optimizers is
    called: ``optimize_solve`` when ``flags.cycle`` is set, ``optimize``
    when a goal dataset exists, and ``optimize_sgd`` otherwise. The elapsed
    wall-clock time of that call is printed at the end.

    Args:
        g: Object passed to ``save_draw_graph`` when ``flags.draw_graph``
            is set.
        sess: Session handed to the optimizer and used for the summary
            writer's graph.
        args: Parsed arguments; ``args.data``, ``args.expl_graph`` and
            ``args.flags`` are read here, and ``args`` is passed to ``Flags``.

    Returns:
        None.
    """
    input_data = None if args.data is None else expl_graph.load_input_data(args.data)
    graph, options = expl_graph.load_explanation_graph(args.expl_graph, args.flags)

    flags = Flags(args, options)
    flags.update()

    # Resolve the loss class named by flags.sgd_loss from the "loss/" directory.
    loader = expl_graph.LossLoader()
    loader.load_all("loss/")
    loss_cls = loader.get_loss(flags.sgd_loss)

    tensor_provider = expl_graph.SwitchTensorProvider()

    # Both generators are optional and stay None when their flag is unset.
    embedding_generator = None
    if flags.embedding:
        embedding_generator = expl_graph.EmbeddingGenerator()
        embedding_generator.load(flags.embedding)

    cycle_embedding_generator = None
    if flags.cycle:
        cycle_embedding_generator = expl_graph.CycleEmbeddingGenerator()
        cycle_embedding_generator.load(options)

    switch_tensors = tensor_provider.build(
        graph,
        options,
        input_data,
        flags,
        load_embeddings=False,
        embedding_generator=embedding_generator,
    )

    builder = expl_graph.ComputationalExplGraph()
    goal_inside = builder.build_explanation_graph(
        graph, tensor_provider, cycle_embedding_generator
    )

    goal_dataset = (
        None
        if input_data is None
        else build_goal_dataset(input_data, tensor_provider)
    )

    if flags.draw_graph:
        save_draw_graph(g, "test")

    loss, output = loss_cls().call(graph, goal_inside, tensor_provider)

    # Summary ops are only set up when a (truthy) loss was produced.
    if loss:
        with tf.name_scope("summary"):
            tf.summary.scalar("loss", loss)
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter("./tf_logs", sess.graph)

    print("traing start")
    for variable in tf.trainable_variables():
        print("train var:", variable.name, variable.shape)

    began = time.time()
    generators = [embedding_generator, cycle_embedding_generator]
    if flags.cycle:
        optimize_solve(sess, goal_dataset, goal_inside, flags, generators)
    elif goal_dataset is None:
        # No dataset available: goal_dataset (None) is still forwarded.
        optimize_sgd(sess, goal_dataset, loss, flags, generators)
    else:
        optimize(sess, goal_dataset, loss, flags, generators)
    elapsed = time.time() - began

    print("traing time:{0}".format(elapsed) + "[sec]")