import copy
import math
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Model and data_load, along with the module-level globals used in sampler()
# (test_X, test_Y, test_len, n_steps_out, steps_rmse_val, data_set, learnr,
# numSamples), are assumed to be defined elsewhere in this script.


class MCMC:
    def __init__(self, samples, topology, use_langevin_gradients, lr,
                 batch_size, step_size):
        self.samples = samples
        self.topology = topology
        self.lr = lr
        self.batch_size = batch_size
        self.use_langevin_gradients = use_langevin_gradients
        self.l_prob = 0.5
        self.cnn = Model(lr, batch_size, 'CNN')
        self.train_data = data_load(data='train')
        self.test_data = data_load(data='test')
        self.step_size = step_size  # random-walk proposal standard deviation
        self.learn_rate = lr

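    # Gaussian log-likelihood of the data given weights w and noise variance
    # tau_sq:
    #   log p(y | w, tau_sq) = sum_i [ -0.5*log(2*pi*tau_sq)
    #                                  - (y_i - f(x_i; w))^2 / (2*tau_sq) ]
    # If w is given, evaluate_proposal is assumed to load w before predicting.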
    def likelihood_func(self, data, tau_sq=1, w=None):
        flag = False
        for i, dat in enumerate(data, 0):
            inputs, labels = dat
            if (flag):
                y = tf.concat((y, labels), axis=0)
            else:
                y = labels
                flag = True
        if w is not None:
            fx = self.cnn.evaluate_proposal(data, w)
        else:
            fx = self.cnn.evaluate_proposal(data)
        # Alternative: rmse = self.rmse(fx, y). Here the model's accumulated
        # loss (self.cnn.los, assumed to be set by evaluate_proposal) is used.
        rmse = copy.deepcopy(self.cnn.los) / len(data)
        # Gaussian log-likelihood: sum over points of
        #   -0.5 * log(2*pi*tau_sq) - (y - fx)^2 / (2*tau_sq)
        loss = np.sum(-0.5 * np.log(2 * math.pi * tau_sq) -
                      0.5 * np.square(y - fx) / tau_sq)
        return [loss, fx, rmse]  # optionally divide loss by self.adapttemp for tempering

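    # Log-density, up to an additive constant, of an isotropic Gaussian prior
    # w ~ N(0, sigma_squared * I):
    #   log p(w) = -(len(w)/2)*log(sigma_squared) - sum(w^2)/(2*sigma_squared)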
    def prior_likelihood(self, sigma_squared, w_list):
        part1 = -1 * (len(w_list) / 2) * np.log(sigma_squared)
        part2 = 1 / (2 * sigma_squared) * (sum(np.square(w_list)))
        log_prior = part1 - part2
        return log_prior

    def rmse(self, pred, actual):
        error = np.subtract(pred, actual)
        sqerror = np.sum(np.square(error)) / actual.shape[0]
        return np.sqrt(sqerror)

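    # Main MCMC loop. Each iteration proposes new weights by either a pure
    # random-walk step or (with probability l_prob) a Langevin-gradient step,
    # proposes a new noise variance via a random walk on log(tau_sq), and
    # accepts or rejects the joint proposal with the Metropolis-Hastings rule.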
    def sampler(self):
        print("chain running")
        samples = self.samples
        samples_run = 0  # samples carried over from a previous run, if resuming

        # Random initialisation of weights (taken from the model)
        w = self.cnn.trainable_weights
        w_size = len(self.cnn.getparameters(w))
        step_w = self.step_size

        rmse_train = np.zeros(samples)
        rmse_test = np.zeros(samples)
        # acc_train = np.zeros(samples)
        # acc_test = np.zeros(samples)
        weight_array = np.zeros(samples)
        weight_array1 = np.zeros(samples)
        weight_array2 = np.zeros(samples)
        weight_array3 = np.zeros(samples)
        weight_array4 = np.zeros(samples)
        likelihood_array = np.zeros(samples)
        sum_value_array = np.zeros(samples)

        learn_rate = self.learn_rate
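        # Stack all training labels into one tensor for the initial
        # noise-variance estimate below.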
        flag = False
        for i, sample in enumerate(self.train_data, 0):
            _, label = sample
            if (flag):
                y_train = tf.concat((y_train, label), axis=0)
            else:
                flag = True
                y_train = label

        pred_train = self.cnn.evaluate_proposal(self.train_data)


        step_eta = 0.2  # random-walk step on log(tau_sq)

        # Initialise tau_sq from the residual variance of the initial predictions
        eta = np.log(np.var(pred_train - y_train))
        tau_pro = np.exp(eta)

        w_proposal = np.random.randn(w_size)
        # w_proposal =self.cnn.dictfromlist(w_proposal)
        train = self.train_data
        test = self.test_data

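        # Fixed variance of the Gaussian prior over the weights (a modelling choice)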
        sigma_squared = 25
        prior_current = self.prior_likelihood(
            sigma_squared, self.cnn.getparameters(w))

        # Evaluate likelihoods at the initial weights
        [likelihood, pred_train,
         rmsetrain] = self.likelihood_func(train, tau_pro)
        [_, pred_test, rmsetest] = self.likelihood_func(test, tau_pro)

        # Beginning sampling using MCMC


        num_accepted = 0  # TODO: save this
        langevin_count = 0
        lcount_acc = 0
        ncount_acc = 0  # TODO: save this

        # if(load):
        #   [langevin_count, num_accepted] = np.loadtxt(
        #      self.path+'/parameters/langevin_count_'+str(self.temperature) + '.txt')
        # TODO: remember to add number of samples from last run
        # Parallel-tempering leftover: in the tempered variant the chain runs
        # at an adaptive temperature for the first pt_samples draws, then
        # switches to canonical MCMC.
        pt_samples = 500 * 0.6
        init_count = 0

        rmse_train[0] = np.sqrt(rmsetrain)
        rmse_test[0] = np.sqrt(rmsetest)

        weight_array[0] = 0
        weight_array1[0] = 0
        weight_array2[0] = 0
        weight_array3[0] = 0
        weight_array4[0] = 0
        likelihood_array[0] = 0

        sum_value_array[0] = 0
        print("beginnning sampling")
        import time
        start = time.time()

        for i in range(samples):  # Begin sampling
            # TODO: remember to add number of samples from last run in i (i+2400<pt_samples)
            # if (i+samples_run) < pt_samples:
            #    self.adapttemp = self.temperature  # T1=T/log(k+1);
            # if i == pt_samples and init_count == 0:  # Move to canonical MCMC
            # self.adapttemp = 1
            [likelihood, pred_train,
             rmsetrain] = self.likelihood_func(train, tau_pro, w)
            [_, pred_test, rmsetest] = self.likelihood_func(test, tau_pro, w)
            init_count = 1

            lx = np.random.uniform(0, 1)
            # Snapshot current weights so they can be restored on rejection
            # (addnoiseandcopy is assumed to perturb the model in place).
            old_w = copy.deepcopy(w)

            l = 0  # set to 1 when the proposal uses Langevin gradients

            if self.use_langevin_gradients and lx < self.l_prob:
                # print("Length of Train ", len(train))
                w_gd = self.cnn.langevin_gradient(train)
                w_proposal = self.cnn.addnoiseandcopy(0, step_w)
                w_prop_gd = self.cnn.langevin_gradient(train)
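                # Metropolis-Hastings correction for the asymmetric Langevin
                # proposal: diff_prop = log q(w | w') - log q(w' | w), where
                # q(. | v) is a Gaussian centred at the gradient-updated v.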
                wc_delta = (self.cnn.getparameters(w) -
                            self.cnn.getparameters(w_prop_gd))
                wp_delta = (self.cnn.getparameters(w_proposal) -
                            self.cnn.getparameters(w_gd))
                sigma_sq = step_w
                first = -0.5 * np.sum(wc_delta * wc_delta) / sigma_sq
                second = -0.5 * np.sum(wp_delta * wp_delta) / sigma_sq
                diff_prop = first - second
                # diff_prop /= self.adapttemp  # tempering variant
                langevin_count = langevin_count + 1
                l = 1
            else:
                diff_prop = 0
                w_proposal = self.cnn.addnoiseandcopy(0, step_w)
                l = 0

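            # Propose a new noise variance via a Gaussian random walk on
            # eta = log(tau_sq), which keeps tau_sq positive.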
            eta_pro = eta + np.random.normal(0, step_eta)
            tau_pro = math.exp(eta_pro)

            [likelihood_proposal, pred_train,
             rmsetrain] = self.likelihood_func(train, tau_pro)
            [likelihood_ignore, pred_test,
             rmsetest] = self.likelihood_func(test, tau_pro)

            prior_prop = self.prior_likelihood(
                sigma_squared, self.cnn.getparameters(w_proposal))
            diff_likelihood = likelihood_proposal - likelihood
            diff_prior = prior_prop - prior_current

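            # Metropolis-Hastings acceptance probability
            # min(1, exp(diff_likelihood + diff_prior + diff_prop));
            # exp() may overflow for large positive sums, in which case the
            # proposal is accepted outright.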
            try:
                mh_prob = min(
                    1, math.exp(diff_likelihood + diff_prior + diff_prop))
            except OverflowError:
                mh_prob = 1

            sum_value = diff_likelihood + diff_prior + diff_prop
            sum_value_array[i] = sum_value
            u = random.uniform(0, 1)
            if u < mh_prob:
                num_accepted = num_accepted + 1
                if (l == 1):
                    lcount_acc += 1
                else:
                    ncount_acc += 1
                likelihood = likelihood_proposal
                prior_current = prior_prop

                eta = eta_pro

                w = copy.deepcopy(
                    w_proposal)  # self.cnn.getparameters(w_proposal)
                # acc_train1 = self.accuracy(train)
                # acc_test1 = self.accuracy(test)
                # if(l==1):
                # print("Langevin gradient proposal accepted")
                print(i + samples_run, np.sqrt(rmsetrain), np.sqrt(rmsetest),
                      'Accepted')
                # Per-step RMSE on the test set for this accepted sample.
                # test_X, test_Y, test_len, n_steps_out and steps_rmse_val are
                # assumed module-level globals; one slot per forecast step,
                # as the original (i*10) slicing implies n_steps_out == 10.
                final_preds = self.cnn.call(test_X.reshape(test_len - 1, 5, 1))
                for j in range(n_steps_out):
                    a = final_preds[:, j]
                    b = test_Y.reshape(test_len - 1, n_steps_out)[:, j]
                    steps_rmse_val[i * n_steps_out + j] = self.rmse(a, b)
                rmse_train[i] = np.sqrt(rmsetrain)
                rmse_test[i] = np.sqrt(rmsetest)
                # acc_train[i,] = acc_train1
                # acc_test[i,] = acc_test1

            else:
                w = old_w
                self.cnn.loadparameters(w)
                # acc_train1 = self.accuracy(train)
                # acc_test1 = self.accuracy(test)
                print(i + samples_run, np.sqrt(rmsetrain), np.sqrt(rmsetest),
                      'Rejected')
                # Carry the previous sample's RMSE forward; at i = 0 the
                # pre-loop initial values are kept.
                if i > 0:
                    rmse_train[i] = rmse_train[i - 1]
                    rmse_test[i] = rmse_test[i - 1]
                # acc_train[i,] = acc_train[i - 1,]
                # acc_test[i,] = acc_test[i - 1,]

            # Trace a few individual weights and the log-likelihood for
            # convergence diagnostics.
            ll = self.cnn.getparameters()
            weight_array[i] = ll[10]
            weight_array1[i] = ll[500]
            weight_array2[i] = ll[1000]
            likelihood_array[i] = likelihood
            # weight_array3[i] = ll[4000]
            # weight_array4[i] = ll[8000]

        end = time.time()
        print("\n\nTotal time taken for Sampling : ", (end - start))
        acceptance = num_accepted * 100 / (samples * 1.0)
        print(acceptance, '% was accepted')

        print((langevin_count * 100 / (samples * 1.0)),
              '% of proposals used Langevin gradients')
        print(lcount_acc, 'Langevin proposals were accepted')
        print(ncount_acc, 'random-walk proposals were accepted')
        # Final predictions from the current (last) weights; test_X, test_Y,
        # test_len and n_steps_out are assumed module-level globals.
        final_preds = self.cnn.call(test_X.reshape(test_len - 1, 5, 1))
        print("Shape is :", final_preds.shape)
        step_rmse = np.zeros(n_steps_out)

        for j in range(n_steps_out):
            plt.figure()
            plt.plot(test_Y.reshape(test_len - 1, n_steps_out)[:, j],
                     label='actual')
            plt.plot(final_preds[:, j], label='predicted')
            a = final_preds[:, j]
            b = test_Y.reshape(test_len - 1, n_steps_out)[:, j]
            step_rmse[j] = self.rmse(a, b)
            print("RMSE for Step ", j + 1, ": ", step_rmse[j])
            plt.ylabel('Predicted/Actual')
            plt.xlabel('Time (samples)')
            plt.title('Actual vs Predicted')
            # txt = "RMSE for Step "+str(j+1)+": "+str(self.rmse(a,b))
            # plt.text(5.0, -1, txt)
            plt.legend()
            # data_set, learnr, step_size and numSamples are assumed
            # module-level globals; the results directory must already exist.
            plt.savefig(str(data_set) + '_results' + str(learnr) + '_' +
                        str(step_size) + '_' + str(numSamples) + '/pred_Step' +
                        str(j + 1) + '.png',
                        dpi=300)
            # plt.show()
            plt.close()

        result_step = [
            str(acceptance),
            str(step_rmse),
            str(np.mean(step_rmse))
        ]
        with open(str(data_set) + '_results' + str(learnr) + '_' +
                  str(step_size) + '_' + str(numSamples) + '/step_results.txt',
                  'w',
                  encoding='utf-8') as f:
            f.write('\n'.join(result_step))
        # np.savetxt(str(data_set)+'_results'+str(learnr)+'_'+str(step_size)+ '/acceptance_stepwisermse_meanstepwise.txt',np.asarray([acceptance, step_rmse,np.mean(step_rmse)]))
        # np.savetxt(str(data_set)+'_results'+str(learnr)+'_'+str(step_size)+ '/stepwise_rmse.txt',np.asarray([step_rmse]))
        # np.savetxt(str(data_set)+'_results'+str(learnr)+'_'+str(step_size)+ '/mean_stepwise_rmse.txt',np.asarray(np.mean(step_rmse)))
        print("Mean value of RMSE over 10 steps :", str(np.mean(step_rmse)))

        return rmse_train, rmse_test, sum_value_array, weight_array, weight_array1, weight_array2, likelihood_array  # acc_train, acc_test,
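

# A minimal usage sketch with hypothetical hyperparameter values; it assumes
# Model, data_load and the module-level globals referenced in sampler() are
# defined earlier in the script:
#
#     mcmc = MCMC(samples=1000, topology=None, use_langevin_gradients=True,
#                 lr=0.01, batch_size=16, step_size=0.005)
#     (rmse_train, rmse_test, sum_values, w0, w1, w2,
#      likelihoods) = mcmc.sampler()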