Exemplo n.º 1
0
def main():
    """Enumerate all input bit vectors and dump those the discriminator accepts.

    Usage: ``<script> [directory]``. The discriminator is loaded from
    ``<directory>_ad/`` and every vector whose discriminator output exceeds
    0.5 is appended to ``generated_actions.csv`` at the path returned by
    ``discriminator.local``.

    The input width ``N`` is read from ``discriminator.net.input_shape[1]``.
    The ``2**N`` candidates are scanned in batches of ``2**lowbit`` rows:
    the low ``lowbit`` columns run through every combination inside a
    batch, while the remaining high columns are fixed per iteration.
    Ctrl-C stops the scan; rows written so far remain in the file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    discriminator = Discriminator(directory_ad).load()
    name = "generated_actions.csv"

    N = discriminator.net.input_shape[1]
    # Clamp to N: with fewer than 20 input bits a fixed lowbit of 20 would
    # make highbit negative and 2**highbit a float, which range() rejects.
    lowbit = min(20, N)
    highbit = N - lowbit
    print("batch size: {}".format(2**lowbit))

    # Row r holds the binary expansion of r in its first `lowbit` columns
    # (least significant bit first). The high columns start at zero and are
    # rewritten on every iteration, so only the low part is computed here.
    xs = np.zeros((2**lowbit, N), dtype=int)
    xs[:, :lowbit] = (np.arange(2**lowbit)[:, None] &
                      (1 << np.arange(lowbit))) > 0

    try:
        path = discriminator.local(name)
        print(path)
        with open(path, 'wb') as f:
            for i in range(2**highbit):
                print("Iteration {}/{} base: {}".format(
                    i, 2**highbit, i * (2**lowbit)),
                      end=' ')
                # Spread the bits of i over the high columns. Python ints
                # are used so the width is not limited by a numpy dtype.
                xs[:, lowbit:] = np.array(
                    [(i >> bit) & 1 for bit in range(highbit)], dtype=int)
                ys = discriminator.discriminate(xs, batch_size=100000)
                # np.where returns one index array per dimension of ys.
                # Taking the first keeps whole rows whether ys is 1-D or
                # has shape (n, 1) -- TODO confirm the discriminator's
                # output shape.
                accepted_rows = np.where(ys > 0.5)[0]
                valid_xs = xs[accepted_rows, :]
                print(len(valid_xs))
                np.savetxt(f, valid_xs, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
Exemplo n.º 2
0
def main():
    """Generate candidate actions from nearby state pairs and keep the accepted ones.

    Usage: ``<script> [directory]``. Loads the discriminator from
    ``<directory>_ad/``, the known actions from ``<directory>/actions.csv``
    and the candidate states from ``<directory>/generated_states.csv``.

    For each state ``s``, every state whose summed absolute difference to
    ``s`` is below ``maxdiff(actions)`` is paired with ``s`` as a row
    ``[neighbor, s]``. Rows whose discriminator output exceeds 0.8 are
    appended to ``<directory>/generated_actions.csv``. Ctrl-C stops the
    scan; rows written so far remain in the file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")

    # valid_states  = load("{}/states.csv".format(directory))
    known_actions = load("{}/actions.csv".format(directory))
    threshold = maxdiff(known_actions)
    print("maxdiff:", threshold)

    states = load("{}/generated_states.csv".format(directory))

    path = "{}/generated_actions.csv".format(directory)

    total = states.shape[0]
    N = states.shape[1]
    acc = 0  # running count of accepted rows over all iterations

    try:
        print(path)
        with open(path, 'wb') as f:
            for i, s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i, total, i * total),
                      end=' ')
                # NOTE(review): if `load` returns an unsigned dtype, the
                # subtraction wraps around before abs() -- confirm dtype.
                diff = np.sum(np.abs(states - s), axis=1)
                neighbors = states[diff < threshold]
                # Append N zero columns, then fill them with s so each row
                # reads [neighbor, s].
                tmp_actions = np.pad(neighbors, ((0, 0), (0, N)), "constant")
                tmp_actions[:, N:] = s
                ys = ad.discriminate(tmp_actions, batch_size=400000)
                # np.where returns one index array per dimension of ys.
                # Taking the first keeps whole rows whether ys is 1-D or
                # has shape (n, 1) -- TODO confirm the discriminator's
                # output shape.
                accepted = tmp_actions[np.where(ys > 0.8)[0], :]
                acc += len(accepted)
                print(len(neighbors), len(accepted), acc)
                np.savetxt(f, accepted, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
def main():
    """Pair every generated state with every other and keep the accepted pairs.

    Usage: ``<script> [directory]``. Loads the discriminator from
    ``<directory>_ad/`` and the states from
    ``<directory>/generated_states2.csv``.

    For each state ``s``, a batch with one row ``[state, s]`` per loaded
    state is scored. Rows whose discriminator output exceeds 0.8 are
    appended to ``generated_actions.csv`` at the path returned by
    ``ad.local``. Ctrl-C stops the scan; rows written so far remain in
    the file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_ad = "{}_ad/".format(directory)
    print("loading the Discriminator", end='...', flush=True)
    ad = Discriminator(directory_ad).load()
    print("done.")
    name = "generated_actions.csv"

    states_path = "{}/generated_states2.csv".format(directory)
    print("loading {}".format(states_path), end='...', flush=True)
    states = np.loadtxt(states_path, dtype=np.uint8)
    print("done.")
    total = states.shape[0]
    N = states.shape[1]
    # Left half: every state. Right half (N zero columns for now) is
    # overwritten in place with the current state on each iteration.
    actions = np.pad(states, ((0, 0), (0, N)), "constant")

    acc = 0  # running count of accepted rows over all iterations

    try:
        out_path = ad.local(name)
        print(out_path)
        with open(out_path, 'wb') as f:
            for i, s in enumerate(states):
                print("Iteration {}/{} base: {}".format(i, total, i * total),
                      end=' ')
                actions[:, N:] = s
                ys = ad.discriminate(actions, batch_size=400000)
                # np.where returns one index array per dimension of ys.
                # Taking the first keeps whole rows whether ys is 1-D or
                # has shape (n, 1) -- TODO confirm the discriminator's
                # output shape.
                valid_actions = actions[np.where(ys > 0.8)[0], :]
                acc += len(valid_actions)
                print(len(valid_actions), acc)
                np.savetxt(f, valid_actions, "%d")
    except KeyboardInterrupt:
        print("dump stopped")
Exemplo n.º 4
0
def main():
    """Filter generated states with the state discriminator.

    Usage: ``<script> [directory]``. Loads the discriminator from
    ``<directory>/_sd`` and the autoencoder from ``<directory>``, then
    reads ``<directory>/generated_states.csv``.

    The states are processed in chunks of 500000 rows. Rows whose rounded
    discriminator output is positive are appended to
    ``generated_states2.csv`` at the path returned by ``ae.local``. For
    each chunk, up to 20 of the surviving rows are decoded and plotted to
    ``generated_states_filtered<i>.png``. Ctrl-C stops the run; rows
    written so far remain in the file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    sd = Discriminator("{}/_sd".format(directory)).load()
    ae = ConvolutionalGumbelAE2(directory).load()

    source_path = "{}/{}".format(directory, "generated_states.csv")
    print("loading {}".format(source_path), end='...', flush=True)
    states = np.loadtxt(source_path, dtype=np.uint8)
    print("done.")
    total = states.shape[0]
    batch = 500000
    output = "generated_states2.csv"
    try:
        out_path = ae.local(output)
        print(out_path)
        with open(out_path, 'wb') as f:
            print("original states:", total)
            # Step by `batch` so no empty trailing chunk is produced when
            # `total` is an exact multiple of `batch` (or zero).
            for i, start in enumerate(range(0, total, batch)):
                _zs = states[start:start + batch]
                _result = sd.discriminate(_zs, batch_size=5000).round().astype(
                    np.uint8)
                _zs_filtered = _zs[np.where(_result > 0)[0], :]
                print("reduced  states:", len(_zs_filtered), "/", len(_zs))

                # Only decode and plot when something survived the filter,
                # so the models are never called with an empty array.
                if len(_zs_filtered) > 0:
                    _xs = ae.decode_binary(_zs_filtered[:20],
                                           batch_size=5000).round().astype(
                                               np.uint8)
                    ae.plot(_xs,
                            path="generated_states_filtered{}.png".format(i))

                np.savetxt(f, _zs_filtered, "%d", delimiter=" ")

    except KeyboardInterrupt:
        print("dump stopped")
def main():
    """Enumerate all state bit vectors and dump those the discriminator accepts.

    Usage: ``<script> [directory]``. The state discriminator is loaded from
    ``<directory>/_sd/`` and every vector whose discriminator output
    exceeds 0.8 is appended to ``generated_states.csv`` at the path
    returned by ``sd.local``.

    The input width ``N`` is read from ``sd.net.input_shape[1]``. The
    ``2**N`` candidates are scanned in batches of ``2**lowbit`` rows: the
    low ``lowbit`` columns run through every combination inside a batch,
    while the remaining high columns are fixed per iteration. Ctrl-C stops
    the scan; rows written so far remain in the file.
    """
    import sys
    if len(sys.argv) == 1:
        sys.exit("{} [directory]".format(sys.argv[0]))

    directory = sys.argv[1]
    directory_sd = "{}/_sd/".format(directory)
    sd = Discriminator(directory_sd).load()
    name = "generated_states.csv"

    N = sd.net.input_shape[1]
    # Clamp to N: with fewer than 20 input bits a fixed lowbit of 20 would
    # make highbit negative and 2**highbit a float, which range() rejects.
    lowbit = min(20, N)
    highbit = N - lowbit
    print("batch size: {}".format(2**lowbit))

    # Row r holds the binary expansion of r in its first `lowbit` columns
    # (least significant bit first). The high columns start at zero and are
    # rewritten on every iteration, so only the low part is computed here.
    xs = np.zeros((2**lowbit, N), dtype=int)
    xs[:, :lowbit] = (np.arange(2**lowbit)[:, None] &
                      (1 << np.arange(lowbit))) > 0

    try:
        path = sd.local(name)
        print(path)
        with open(path, 'wb') as f:
            for i in range(2**highbit):
                print("Iteration {}/{} base: {}".format(
                    i, 2**highbit, i * (2**lowbit)),
                      end=' ')
                # Spread the bits of i over the high columns. Python ints
                # are used so the width is not limited by a numpy dtype.
                xs[:, lowbit:] = np.array(
                    [(i >> bit) & 1 for bit in range(highbit)], dtype=int)
                ys = sd.discriminate(xs, batch_size=100000)
                # First index array from np.where = row indices of the
                # accepted inputs.
                accepted_rows = np.where(ys > 0.8)[0]
                valid_xs = xs[accepted_rows, :]
                print(len(valid_xs))
                np.savetxt(f, valid_xs, "%d", delimiter=" ")
    except KeyboardInterrupt:
        print("dump stopped")
Exemplo n.º 6
0
        # quick eval
        'epoch': [1000],
        'lr': [0.0001],
    }

    train = True
    if train:
        discriminator, _, _ = grid_search(directory_sd, train_in, train_out,
                                          test_in, test_out)
    else:
        discriminator = Discriminator(directory_sd).load()
    print("index, discrimination, action")
    show_n = 30

    for y, _y in zip(
            discriminator.discriminate(test_in)[:show_n], test_out[:show_n]):
        print(y, _y)

    # test if the learned action is correct

    states_valid = np.loadtxt("{}/all_states.csv".format(directory), dtype=int)
    print("valid", states_valid.shape)

    discriminator.report(states_valid,
                         train_data_to=np.ones((len(states_valid), )))
"""

* Summary:

Input: a subset of valid states and states generated by EB-discriminator
Output: a function that returns 0/1 for a state