Example #1
0
                                             real_seqlen: real_batch[1]})

    fake_batch = fake_iterator.next_batch(BATCH_SIZE)
    G_loss_curr, Summary_curr, _ = sess.run([G_loss, merged, gen_train_op],
                                            feed_dict={Z: fake_batch[0],
                                                       fake_seqlen: fake_batch[
                                                           1]})

    # train_writer.add_summary(Summary_curr,global_step=it)

    if it == 0:
        if REAL_DATA:
            pass
            # integral_intensity = get_integral_empirical(real_sequences, intensity_real,T,n_t)
        elif DATA != "rmtpp":
            integral_intensity = get_integral(real_sequences, DATA)
            integral_intensity = np.asarray(integral_intensity)
            fig = sm.qqplot(integral_intensity, stats.expon,
                            distargs=(), loc=0, scale=1, line='45')
            plt.grid()
            fig.savefig('logs/out/{}/real.png'.format(saved_file))
            plt.close()

    if it % 1000 == 0:
        sequences_generator = []
        for _ in range(int(1000 / BATCH_SIZE)):
            sequences_gen = sess.run(fake_data, feed_dict={
                Z: fake_batch[0], fake_seqlen: fake_batch[1]})
            shape_gen = sequences_gen.shape
            sequences_gen = np.reshape(
                sequences_gen, (shape_gen[0], shape_gen[1]))
Example #2
0
File: MLE.py — Project: zhh0998/pp
        deviation = np.linalg.norm(
            intensity_gen - intensity_real) / np.linalg.norm(intensity_real)
        # can use correlation or other metric
        print('Iter: {};  deviation: {}'.format(it, deviation))
        plt.plot(ts_real, intensity_real, label='real')
        plt.plot(ts_gen, intensity_gen, label='generated')
        plt.legend(loc=1)
        plt.xlabel('time')
        plt.ylabel('intensity')
        plt.savefig('out/{}/{}_{}.png'.format(saved_file,
                                              str(it).zfill(3), deviation),
                    bbox_inches='tight')
        plt.close()

        if not REAL_DATA and DATA != "rmtpp":
            integral_intensity = get_integral(generated_sequences, DATA)
            integral_intensity = np.asarray(integral_intensity)
            fig = plt.figure()
            left = -1.8  #x coordinate for text insert
            ax1 = fig.add_subplot(1, 2, 1)
            fig = sm.qqplot(integral_intensity,
                            stats.expon,
                            distargs=(),
                            loc=0,
                            scale=1,
                            line='45',
                            ax=ax1)
            plt.grid()
            ax2 = fig.add_subplot(1, 2, 2)
            top = ax2.get_ylim()[1] * 0.75
            res, slope_intercept = stats.probplot(integral_intensity,