Пример #1
0
def run_preparing(g, sess, args):
    """Load inputs and the explanation graph, then build the switch tensors.

    Reads the input data from ``args.data`` and the explanation graph from
    ``args.expl_graph`` / ``args.flags``, builds the ``Flags`` configuration,
    resolves the loss class named by ``flags.sgd_loss`` from the ``loss/``
    directory, and builds the switch tensors with ``load_embeddings=False``.
    An ``EmbeddingGenerator`` is loaded only when ``flags.embedding`` is set.

    Args:
        g: unused in the visible code.
        sess: unused in the visible code.
        args: parsed command-line arguments (``data``, ``expl_graph``,
            ``flags`` and whatever ``Flags`` reads).
    """
    data = expl_graph.load_input_data(args.data)
    expl, opts = expl_graph.load_explanation_graph(args.expl_graph, args.flags)
    cfg = Flags(args, opts)
    cfg.update()

    # Resolve the loss implementation selected by the flags.
    losses = expl_graph.LossLoader()
    losses.load_all("loss/")
    _loss_cls = losses.get_loss(cfg.sgd_loss)

    # Switch tensors; stored embeddings are not loaded at preparation time.
    provider = expl_graph.SwitchTensorProvider()
    if cfg.embedding:
        emb_gen = expl_graph.EmbeddingGenerator()
        emb_gen.load(cfg.embedding)
    else:
        emb_gen = None

    _switch_tensors = provider.build(
        expl,
        opts,
        data,
        cfg,
        load_embeddings=False,
        embedding_generator=emb_gen,
    )
Пример #2
0
def run_test(g, sess, args):
    """Restore a trained model and evaluate it on the explanation graph.

    Loads the optional input data (``args.data``) and the explanation graph,
    rebuilds the computational graph, restores variables from ``flags.model``
    into ``sess``, then evaluates in one of three modes:

    * ``flags.cycle``: delegates to ``optimize_solve``; no outputs are
      collected here, so nothing is saved.
    * input data given: runs each goal's dataset in minibatches of
      ``flags.sgd_minibatch_size`` and prints the mean loss and output shape.
    * no input data: runs every loss/output once with a single feed dict and
      also fetches each goal's ``"inside"`` tensor.

    Collected outputs are written to ``flags.output`` with ``np.save``. The
    per-goal inside values (no-data mode only) are pickled to ``output.pkl``.

    Args:
        g: object passed to ``save_draw_graph`` when ``flags.draw_graph`` is set.
        sess: session used to restore and run the model.
        args: parsed command-line arguments (``data``, ``expl_graph``,
            ``flags`` and whatever ``Flags`` reads).
    """
    if args.data is not None:
        input_data = expl_graph.load_input_data(args.data)
    else:
        input_data = None
    graph, options = expl_graph.load_explanation_graph(args.expl_graph, args.flags)
    flags = Flags(args, options)
    flags.update()

    # Resolve the loss implementation selected by the flags.
    loss_loader = expl_graph.LossLoader()
    loss_loader.load_all("loss/")
    loss_cls = loss_loader.get_loss(flags.sgd_loss)

    # Switch tensors; at test time stored embeddings are loaded.
    tensor_provider = expl_graph.SwitchTensorProvider()
    embedding_generator = None
    if flags.embedding:
        embedding_generator = expl_graph.EmbeddingGenerator()
        embedding_generator.load(flags.embedding, key="test")
    cycle_embedding_generator = None
    if flags.cycle:
        cycle_embedding_generator = expl_graph.CycleEmbeddingGenerator()
        cycle_embedding_generator.load(options)
    tensor_provider.build(
        graph,
        options,
        input_data,
        flags,
        load_embeddings=True,
        embedding_generator=embedding_generator,
    )
    comp_expl_graph = expl_graph.ComputationalExplGraph()
    goal_inside = comp_expl_graph.build_explanation_graph(
        graph, tensor_provider, cycle_embedding_generator
    )
    if input_data is not None:
        goal_dataset = build_goal_dataset(input_data, tensor_provider)
    else:
        goal_dataset = None
    if flags.draw_graph:
        save_draw_graph(g, "test")
    loss, output = loss_cls().call(graph, goal_inside, tensor_provider)
    saver = tf.train.Saver()
    saver.restore(sess, flags.model)

    # Only the non-cycle modes collect results; None means "nothing to save".
    total_output = None
    total_goal_inside = None
    if flags.cycle:
        optimize_solve(
            sess,
            goal_dataset,
            goal_inside,
            flags,
            [embedding_generator, cycle_embedding_generator],
        )
    elif goal_dataset is not None:
        # Dataset is given: evaluate each goal in minibatches.
        batch_size = flags.sgd_minibatch_size
        total_loss = [[] for _ in range(len(goal_dataset))]
        total_output = [[] for _ in range(len(goal_dataset))]
        for j, goal in enumerate(goal_dataset):
            ph_vars = goal["placeholders"]
            dataset = goal["dataset"]
            num = dataset.shape[1]
            num_itr = (num + batch_size - 1) // batch_size
            progbar = None if flags.no_verb else tf.keras.utils.Progbar(num_itr)
            idx = list(range(num))
            for itr in range(num_itr):
                temp_idx = idx[itr * batch_size : (itr + 1) * batch_size]
                if len(temp_idx) < batch_size:
                    # Pad the final short batch to batch_size with index 0
                    # (presumably the placeholders have a fixed batch size);
                    # the padded rows are trimmed off the results below.
                    batch_idx = np.zeros((batch_size,), dtype=np.int32)
                    batch_idx[: len(temp_idx)] = temp_idx
                else:
                    batch_idx = temp_idx
                feed_dict = {ph: dataset[i, batch_idx] for i, ph in enumerate(ph_vars)}

                if embedding_generator:
                    feed_dict = embedding_generator.build_feed(feed_dict)
                batch_loss, batch_output = sess.run(
                    [loss[j], output[j]], feed_dict=feed_dict
                )
                if progbar is not None:
                    # Progbar.update takes the number of completed steps.
                    progbar.update(itr + 1)
                total_loss[j].extend(batch_loss[: len(temp_idx)])
                total_output[j].extend(batch_output[: len(temp_idx)])
            print("loss:", np.mean(total_loss[j]))
            print("output:", np.array(total_output[j]).shape)
    else:
        # No dataset: evaluate every loss/output once with a single feed.
        feed_dict = {}
        total_loss = []
        total_output = []
        if embedding_generator:
            feed_dict = embedding_generator.build_feed(feed_dict)
        for j in range(len(loss)):
            j_loss, j_output = sess.run([loss[j], output[j]], feed_dict=feed_dict)
            total_loss.append(j_loss)
            total_output.append(j_output)
        print("loss:", np.mean(total_loss))
        print("output:", np.array(total_output).shape)
        total_goal_inside = []
        for goal_tensors in goal_inside:
            g_inside = sess.run([goal_tensors["inside"]], feed_dict=feed_dict)
            total_goal_inside.append(g_inside[0])

    if total_output is None:
        # Cycle mode: optimize_solve collects no outputs here, so there is
        # nothing to save.
        return

    print("[SAVE]", flags.output)
    # NOTE(review): in minibatch mode goals with different example counts
    # give a ragged list, which recent numpy refuses to convert; confirm.
    np.save(flags.output, total_output)

    if total_goal_inside is not None:
        # Map each goal node id to its printable name and inside value.
        data = {}
        for g_info, inside in zip(graph.goals, total_goal_inside):
            name = to_string_goal(g_info.node.goal)
            data[g_info.node.id] = {"name": name, "data": inside}
        with open("output.pkl", "wb") as fp:
            pickle.dump(data, fp)
Пример #3
0
def run_training(g, sess, args):
    """Build the computational explanation graph and train its parameters.

    Loads the optional input data (``args.data``) and the explanation graph,
    configures ``Flags``, resolves the loss class named by ``flags.sgd_loss``,
    and builds the switch tensors without loading stored embeddings
    (``load_embeddings=False``). It then prints the trainable variables and
    dispatches to one optimizer:
    ``optimize_solve`` when ``flags.cycle`` is set, ``optimize`` when input
    data was given, and ``optimize_sgd`` otherwise. Finally it prints the
    wall-clock training time.

    Args:
        g: object passed to ``save_draw_graph`` when ``flags.draw_graph`` is set.
        sess: session used for training (its graph is written to ``./tf_logs``).
        args: parsed command-line arguments (``data``, ``expl_graph``,
            ``flags`` and whatever ``Flags`` reads).
    """
    if args.data is not None:
        input_data = expl_graph.load_input_data(args.data)
    else:
        input_data = None
    graph, options = expl_graph.load_explanation_graph(args.expl_graph, args.flags)
    flags = Flags(args, options)
    flags.update()
    ## Resolve the loss implementation selected by the flags.
    loss_loader = expl_graph.LossLoader()
    loss_loader.load_all("loss/")
    loss_cls = loss_loader.get_loss(flags.sgd_loss)
    ## Switch tensors (and optional embedding generators).
    tensor_provider = expl_graph.SwitchTensorProvider()
    embedding_generator = None
    if flags.embedding:
        embedding_generator = expl_graph.EmbeddingGenerator()
        embedding_generator.load(flags.embedding)
    cycle_embedding_generator = None
    if flags.cycle:
        cycle_embedding_generator = expl_graph.CycleEmbeddingGenerator()
        cycle_embedding_generator.load(options)
    # tensor_embedding is not used in the visible code; build() is called for
    # its effect on tensor_provider.
    tensor_embedding = tensor_provider.build(
        graph,
        options,
        input_data,
        flags,
        load_embeddings=False,
        embedding_generator=embedding_generator,
    )
    comp_expl_graph=expl_graph.ComputationalExplGraph()
    goal_inside = comp_expl_graph.build_explanation_graph(
        graph, tensor_provider, cycle_embedding_generator
    )
    if input_data is not None:
        goal_dataset = build_goal_dataset(input_data, tensor_provider)
    else:
        goal_dataset = None

    if flags.draw_graph:
        # NOTE(review): the "test" label during training looks copied from
        # run_test; confirm whether "train" was intended.
        save_draw_graph(g, "test")

    loss, output = loss_cls().call(graph, goal_inside, tensor_provider)
    if loss:
        # NOTE(review): loss is indexed per goal elsewhere in this file
        # (loss[j]), so it may be a list; tf.summary.scalar expects a scalar
        # tensor — confirm. merged/writer are not used again in the visible
        # code and the writer is never closed here.
        with tf.name_scope("summary"):
            tf.summary.scalar("loss", loss)
            merged = tf.summary.merge_all()
            writer = tf.summary.FileWriter("./tf_logs", sess.graph)
    ## List the variables the optimizer will update.
    print("traing start")
    vars_to_train = tf.trainable_variables()
    for var in vars_to_train:
        print("train var:", var.name, var.shape)
    ## Optimize, timing the whole run.
    start_t = time.time()
    if flags.cycle:
        optimize_solve(
            sess,
            goal_dataset,
            goal_inside,
            flags,
            [embedding_generator, cycle_embedding_generator],
        )
    elif goal_dataset is not None:
        optimize(
            sess,
            goal_dataset,
            loss,
            flags,
            [embedding_generator, cycle_embedding_generator],
        )
    else:
        # No input data: goal_dataset is None on this path.
        optimize_sgd(
            sess,
            goal_dataset,
            loss,
            flags,
            [embedding_generator, cycle_embedding_generator],
        )
    train_time = time.time() - start_t
    print("traing time:{0}".format(train_time) + "[sec]")