Example No. 1

def export_attention(model, data, name="transformer"):
    # Restore the vocabulary maps and the trained Transformer weights.
    with open("./visual/tmp/transformer_v2i_i2v.pkl", "rb") as f:
        dic = pickle.load(f)
    model.load_weights("./visual/models/transformer/model.ckpt")

    # Translate a small sample so the model records its attention weights.
    bx, by, seq_len = data.sample(32)
    model.translate(bx, dic["v2i"], dic["i2v"])

    # Decode the index sequences back to tokens and dump everything for visualization.
    attn_data = {
        "src": [[data.i2v[i] for i in bx[j]] for j in range(len(bx))],
        "tgt": [[data.i2v[i] for i in by[j]] for j in range(len(by))],
        "attentions": model.attentions
    }
    path = "./visual/tmp/%s_attention_matrix.pkl" % name
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(attn_data, f)


if __name__ == "__main__":
    utils.set_soft_gpu(True)
    d = utils.DateData(4000)
    print("Chinese time order: yy/mm/dd ", d.date_cn[:3],
          "\nEnglish time order: dd/M/yyyy ", d.date_en[:3])
    print("vocabularies: ", d.vocab)
    print("x index sample: \n{}\n{}".format(d.idx2str(d.x[0]), d.x[0]),
          "\ny index sample: \n{}\n{}".format(d.idx2str(d.y[0]), d.y[0]))

    m = Transformer(MODEL_DIM, MAX_LEN, N_LAYER, N_HEAD, d.num_word, DROP_RATE)
    train(m, d, step=800)
    export_attention(m, d)
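
For reference, a minimal sketch of reading the exported file back. The path follows the default name="transformer" used above and the key names come from the attn_data dict; everything else here is an assumption, not part of the original example.

import pickle

# Load the attention data written by export_attention() above.
with open("./visual/tmp/transformer_attention_matrix.pkl", "rb") as f:
    attn_data = pickle.load(f)

src_tokens = attn_data["src"][0]       # first source sentence, as tokens
tgt_tokens = attn_data["tgt"][0]       # first target sentence, as tokens
attentions = attn_data["attentions"]   # raw attention weights recorded by the model

print(src_tokens, "->", tgt_tokens)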
Example No. 2

def init_logger(date_str, m):
    logger = utils.get_logger(date_str)
    logger.info(str(args))
    logger.info("model parameters: g={}, d={}".format(
        m.g.count_params(), m.d.count_params()))

    try:
        # plot_model needs pydot and Graphviz installed; fail gracefully if they are missing.
        tf.keras.utils.plot_model(m.g, show_shapes=True, expand_nested=True, dpi=150,
                                  to_file="visual/{}/net_g.png".format(date_str))
        tf.keras.utils.plot_model(m.d, show_shapes=True, expand_nested=True, dpi=150,
                                  to_file="visual/{}/net_d.png".format(date_str))
    except Exception as e:
        print(e)
    return logger
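
utils.get_logger is not shown in this example; below is a minimal sketch of what such a helper could look like with Python's standard logging module. The handler setup and log path are assumptions.

import logging
import os


def get_logger(date_str):
    # Write logs both to the console and to a per-run file under visual/<date_str>/.
    os.makedirs("visual/{}".format(date_str), exist_ok=True)
    logger = logging.getLogger(date_str)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    for handler in (logging.StreamHandler(),
                    logging.FileHandler("visual/{}/train.log".format(date_str))):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger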


if __name__ == "__main__":
    utils.set_soft_gpu(args.soft_gpu)

    summary_writer = tf.summary.create_file_writer('visual/{}'.format(date_str))
    d = load_tfrecord(args.batch_size//2, args.data_dir)
    m = StyleGAN(
        img_shape=(args.image_size, args.image_size, 3), latent_dim=args.latent, summary_writer=summary_writer,
        lr=args.lr, beta1=args.beta1, beta2=args.beta2, lambda_=args.lambda_, wgan=args.wgan)
    logger = init_logger(date_str, m)
    train(m, d)
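
Both create_file_writer and init_logger use a date_str variable that is not defined in this excerpt; it presumably identifies the run, e.g. a timestamp. The exact format below is an assumption.

import datetime

# One directory name per run, shared by the summary writer, the logger and plot_model output.
date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")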


Example No. 3
    t0 = time.time()
    for ep in range(EPOCH):
        for t, (img, _) in enumerate(ds):
            g_img, d_loss, d_acc, g_loss, g_acc = gan.step(img)
            if t % 400 == 0:
                # Report wall-clock time, discriminator/generator accuracy and loss every 400 batches.
                t1 = time.time()
                print(
                    "ep={} | time={:.1f} | t={} | d_acc={:.2f} | g_acc={:.2f} | d_loss={:.2f} | g_loss={:.2f}"
                    .format(
                        ep,
                        t1 - t0,
                        t,
                        d_acc.numpy(),
                        g_acc.numpy(),
                        d_loss.numpy(),
                        g_loss.numpy(),
                    ))
                t0 = t1
        save_gan(gan, ep)
    save_weights(gan)


if __name__ == "__main__":
    LATENT_DIM = 100
    IMG_SHAPE = (28, 28, 1)
    BATCH_SIZE = 64
    EPOCH = 20

    set_soft_gpu(True)
    train()
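
The dataset ds iterated in the training loop is not built in this excerpt; given IMG_SHAPE = (28, 28, 1), a plausible input pipeline is sketched below. The choice of MNIST and the tf.data calls are assumptions, not part of the original.

import tensorflow as tf


def get_mnist_ds(batch_size):
    # Scale pixels to [-1, 1] and add a channel axis to match IMG_SHAPE = (28, 28, 1).
    (x, y), _ = tf.keras.datasets.mnist.load_data()
    x = x[..., None].astype("float32") / 255.0 * 2.0 - 1.0
    return tf.data.Dataset.from_tensor_slices((x, y)).shuffle(1024).batch(batch_size)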