def fit_network_hawkes_svi(S, K, C, dt, dt_max, output_path, standard_model=None, N_iters=500, true_network=None):
    """Fit a discrete-time network Hawkes model to ``S`` with SVI, caching results on disk.

    Cached results are looked up first, in this order:

    1. ``output_path + ".svi.pkl.gz"`` -- a gzipped pickle of ``(samples, timestamps)``.
    2. ``output_path + ".svi.itr%04d.pkl" % (N_iters - 1)`` -- the pickle of the final
       iteration's sample only. It is returned as a one-element list and
       ``timestamps`` is ``None``.

    If neither file exists, a ``DiscreteTimeNetworkHawkesModelGammaMixtureSBM`` is
    built, ``N_iters`` SVI steps are run, every sample is pickled to its own
    per-iteration file, and the whole run is written to the gzipped pickle.

    Args:
        S: Data handed unchanged to ``test_model.add_data``.
        K: Forwarded to the model constructor as ``K``.
        C: Accepted for interface compatibility but not used; the ``"C"`` network
            hyperparameter is hard-coded to 2 below.
        dt: Forwarded to the basis and the model.
        dt_max: Forwarded to the basis and the model.
        output_path: Path prefix for all cache files read and written.
        standard_model: If not ``None``, passed to
            ``test_model.initialize_with_standard_model`` before fitting.
        N_iters: Number of SVI iterations; also selects which per-iteration
            file counts as the "final" cached sample.
        true_network: If not ``None``, ``roc_auc_score`` of its ravelled values
            against the ravelled ``expected_W()`` is printed before every step.

    Returns:
        Tuple ``(samples, timestamps)``. ``timestamps`` is ``None`` when the result
        came from a per-iteration file.

    NOTE: unpickling runs arbitrary code; only point ``output_path`` at files
    produced by a trusted run of this function.
    """
    # cPickle only exists on Python 2; fall back to the Python 3 module.
    try:
        import cPickle as pickle_impl
    except ImportError:
        import pickle as pickle_impl

    full_path = output_path + ".svi.pkl.gz"
    last_iter_path = output_path + ".svi.itr%04d.pkl" % (N_iters - 1)

    # Check for existing SVI results (complete run first, then final sample only).
    if os.path.exists(full_path):
        with gzip.open(full_path, "rb") as f:
            print("Loading SVI results from  %s" % full_path)
            (samples, timestamps) = pickle_impl.load(f)
        return samples, timestamps

    if os.path.exists(last_iter_path):
        # Pickles are binary data: the file must be opened in binary mode.
        with open(last_iter_path, "rb") as f:
            print("Loading SVI results from  %s" % last_iter_path)
            samples = [pickle_impl.load(f)]
        return samples, None

    print("Fitting the data with a network Hawkes model using SVI")

    # Make a new model for inference
    test_basis = IdentityBasis(dt, dt_max, allow_instantaneous=True)
    E_W = 0.01
    kappa = 10.0
    E_v = kappa / E_W
    alpha = 10.0
    beta = alpha / E_v
    network_hypers = {
        "C": 2,
        "kappa": kappa,
        "alpha": alpha,
        "beta": beta,
        "p": 0.8,
        "allow_self_connections": False,
    }
    test_model = DiscreteTimeNetworkHawkesModelGammaMixtureSBM(
        K=K, dt=dt, dt_max=dt_max, basis=test_basis, network_hypers=network_hypers
    )

    # Initialize with the standard model parameters
    if standard_model is not None:
        test_model.initialize_with_standard_model(standard_model)

    # TODO: Add the data in minibatches
    minibatchsize = 3000
    test_model.add_data(S)

    # Stochastic variational inference with a polynomially decaying step size.
    delay = 10.0
    forgetting_rate = 0.5
    stepsize = (np.arange(N_iters) + delay) ** (-forgetting_rate)

    # time.clock() was removed in Python 3.8; prefer perf_counter() where it exists.
    now = getattr(time, "perf_counter", None) or time.clock

    samples = []
    timestamps = []
    for itr in range(N_iters):
        if true_network is not None:
            W_score = test_model.weight_model.expected_W()
            print("AUC:  %s" % (roc_auc_score(true_network.ravel(), W_score.ravel()),))

        print("SVI Iter:  %s \tStepsize:  %s" % (itr, stepsize[itr]))
        test_model.sgd_step(minibatchsize=minibatchsize, stepsize=stepsize[itr])
        test_model.resample_from_mf()
        samples.append(test_model.copy_sample())
        timestamps.append(now())

        # Save this sample (binary mode is required by the binary pickle protocol).
        with open(output_path + ".svi.itr%04d.pkl" % itr, "wb") as f:
            pickle_impl.dump(samples[-1], f, protocol=-1)

    # Save all SVI samples together with their timestamps.
    with gzip.open(full_path, "wb") as f:
        print("Saving SVI samples to  %s" % full_path)
        pickle_impl.dump((samples, timestamps), f, protocol=-1)

    return samples, timestamps
def fit_network_hawkes_svi(S,
                           K,
                           C,
                           dt,
                           dt_max,
                           output_path,
                           standard_model=None,
                           N_iters=500,
                           true_network=None):
    """Run SVI for a discrete-time network Hawkes model, reusing cached results.

    Lookup order for cached results:

    * ``output_path + ".svi.pkl.gz"``: gzipped pickle holding
      ``(samples, timestamps)`` for a complete run.
    * ``output_path + ".svi.itr%04d.pkl" % (N_iters - 1)``: pickle of the last
      iteration's sample; returned as ``[sample]`` with ``timestamps = None``.

    When no cache is found the model is fit from scratch: one sample is
    pickled per iteration and the complete run is stored in the gzipped file.

    Args:
        S: Data passed as-is to ``test_model.add_data``.
        K: Passed to the model constructor as ``K``.
        C: Unused; kept so existing callers keep working. The ``'C'`` entry
            of the network hyperparameters is fixed at 2.
        dt: Passed to both the basis and the model.
        dt_max: Passed to both the basis and the model.
        output_path: Prefix of every cache file read or written.
        standard_model: Optional model handed to
            ``test_model.initialize_with_standard_model``.
        N_iters: Number of SVI iterations (also determines the name of the
            "last iteration" cache file).
        true_network: Optional array; when given, the ``roc_auc_score`` of
            ``expected_W()`` against it is printed before each step.

    Returns:
        ``(samples, timestamps)``; ``timestamps`` is ``None`` if only the last
        per-iteration sample was loaded.

    NOTE: pickle files are executed on load; only use trusted cache files.
    """
    # Python 2 ships cPickle, Python 3 only has pickle.
    try:
        import cPickle as _pickle
    except ImportError:
        import pickle as _pickle

    gz_path = output_path + ".svi.pkl.gz"
    final_itr_path = output_path + ".svi.itr%04d.pkl" % (N_iters - 1)

    # Check for existing SVI results
    if os.path.exists(gz_path):
        with gzip.open(gz_path, 'rb') as f:
            print("Loading SVI results from  %s" % gz_path)
            (samples, timestamps) = _pickle.load(f)
        return samples, timestamps

    if os.path.exists(final_itr_path):
        # Binary mode: the sample was dumped with a binary pickle protocol.
        with open(final_itr_path, 'rb') as f:
            print("Loading SVI results from  %s" % final_itr_path)
            sample = _pickle.load(f)
        return [sample], None

    print("Fitting the data with a network Hawkes model using SVI")

    # Make a new model for inference
    test_basis = IdentityBasis(dt, dt_max, allow_instantaneous=True)
    E_W = 0.01
    kappa = 10.
    E_v = kappa / E_W
    alpha = 10.
    beta = alpha / E_v
    network_hypers = {
        'C': 2,
        'kappa': kappa,
        'alpha': alpha,
        'beta': beta,
        'p': 0.8,
        'allow_self_connections': False
    }
    test_model = DiscreteTimeNetworkHawkesModelGammaMixtureSBM(
        K=K,
        dt=dt,
        dt_max=dt_max,
        basis=test_basis,
        network_hypers=network_hypers)

    # Initialize with the standard model parameters
    if standard_model is not None:
        test_model.initialize_with_standard_model(standard_model)

    # TODO: Add the data in minibatches
    minibatchsize = 3000
    test_model.add_data(S)

    # Stochastic variational inference; step size decays with the iteration.
    delay = 10.0
    forgetting_rate = 0.5
    stepsize = (np.arange(N_iters) + delay)**(-forgetting_rate)

    # time.clock() no longer exists on Python >= 3.8; use perf_counter() there.
    clock = getattr(time, "perf_counter", None) or time.clock

    samples = []
    timestamps = []
    for itr in range(N_iters):
        if true_network is not None:
            W_score = test_model.weight_model.expected_W()
            auc = roc_auc_score(true_network.ravel(), W_score.ravel())
            print("AUC:  %s" % (auc, ))

        print("SVI Iter:  %s \tStepsize:  %s" % (itr, stepsize[itr]))
        test_model.sgd_step(minibatchsize=minibatchsize,
                            stepsize=stepsize[itr])
        test_model.resample_from_mf()
        samples.append(test_model.copy_sample())
        timestamps.append(clock())

        # Save this sample
        with open(output_path + ".svi.itr%04d.pkl" % itr, 'wb') as f:
            _pickle.dump(samples[-1], f, protocol=-1)

    # Save the SVI samples
    with gzip.open(gz_path, 'wb') as f:
        print("Saving SVI samples to  %s" % gz_path)
        _pickle.dump((samples, timestamps), f, protocol=-1)

    return samples, timestamps
def fit_network_hawkes_svi(S,
                           K,
                           C,
                           dt,
                           dt_max,
                           output_path,
                           standard_model=None,
                           N_iters=100,
                           true_network=None):
    """
    Fit a discrete-time network Hawkes model to ``S`` with SVI, with on-disk caching.

    From Scott Linderman's experiments in https://github.com/slinderman/pyhawkes/tree/master/experiments

    Cached results are tried first:

    * ``output_path + ".svi.pkl.gz"``: gzipped pickle of ``(samples, timestamps)``.
    * ``output_path + ".svi.itr%04d.pkl" % (N_iters - 1)``: pickle of the final
      iteration's sample only; returned as ``[sample]`` with ``timestamps = None``.

    Otherwise the model is fit for ``N_iters`` iterations; each sample is pickled
    to its own per-iteration file and the whole run is saved to the gzipped pickle.

    Args:
        S: Data passed unchanged to ``test_model.add_data``.
        K: Passed to the model constructor as ``K``.
        C: Not used by this function (the ``'C'`` hyperparameter is fixed at 2);
            kept for interface compatibility.
        dt: Passed to the basis and the model.
        dt_max: Passed to the basis and the model.
        output_path: Prefix of every cache file read or written.
        standard_model: If not ``None``, passed to
            ``test_model.initialize_with_standard_model``.
        N_iters: Number of SVI iterations; also selects the "final iteration"
            cache file name.
        true_network: Not used by this function; kept for interface compatibility.

    Returns:
        ``(samples, timestamps)``. ``timestamps`` is ``None`` when only the final
        per-iteration sample was loaded from disk.

    NOTE: unpickling executes arbitrary code; only load cache files you trust.
    """
    full_run_path = output_path + ".svi.pkl.gz"
    final_sample_path = output_path + ".svi.itr%04d.pkl" % (N_iters - 1)

    #------------- Check for existing SVI results
    if os.path.exists(full_run_path):
        with gzip.open(full_run_path, 'rb') as f:
            print("Loading SVI results from ", full_run_path)
            (samples, timestamps) = pickle.load(f)
        return samples, timestamps

    if os.path.exists(final_sample_path):
        # pickle needs a binary file object; text mode fails on Python 3.
        with open(final_sample_path, 'rb') as f:
            print("Loading SVI results from ", final_sample_path)
            sample = pickle.load(f)
        return [sample], None

    print("Fitting the data with a network Hawkes model using SVI")

    #------------- Make a new model for inference
    test_basis = IdentityBasis(dt, dt_max, allow_instantaneous=True)
    E_W = 0.01
    kappa = 10.
    E_v = kappa / E_W
    alpha = 10.
    beta = alpha / E_v
    network_hypers = {
        'C': 2,
        'kappa': kappa,
        'alpha': alpha,
        'beta': beta,
        'p': 0.8,
        'allow_self_connections': False
    }
    test_model = DiscreteTimeNetworkHawkesModelGammaMixtureSBM(
        K=K,
        dt=dt,
        dt_max=dt_max,
        basis=test_basis,
        network_hypers=network_hypers)

    #------------- Initialize with the standard model parameters
    if standard_model is not None:
        test_model.initialize_with_standard_model(standard_model)
    minibatchsize = 3000
    test_model.add_data(S)

    #------------- Stochastic variational inference learning with default algorithm hyperparameters
    delay = 10.0
    forgetting_rate = 0.5
    stepsize = (np.arange(N_iters) + delay)**(-forgetting_rate)

    samples = []
    timestamps = []
    for itr in range(N_iters):

        print("SVI Iter: ", itr, "\tStepsize: ", stepsize[itr])
        test_model.sgd_step(minibatchsize=minibatchsize,
                            stepsize=stepsize[itr])
        test_model.resample_from_mf()
        samples.append(test_model.copy_sample())
        # time.clock() was removed in Python 3.8; perf_counter() is its replacement.
        timestamps.append(time.perf_counter())

        # Binary mode is required: pickle.dump writes bytes.
        with open(output_path + ".svi.itr%04d.pkl" % itr, 'wb') as f:
            pickle.dump(samples[-1], f, protocol=-1)

    with gzip.open(full_run_path, 'wb') as f:
        print("Saving SVI samples to ", full_run_path)
        pickle.dump((samples, timestamps), f, protocol=-1)

    return samples, timestamps