Exemplo n.º 1
0
class Regressor:
    """Neural network regressor for the Boston house-pricing problem.

    The model's parameters can be exported and imported as one flat 1-D
    vector (``get_weights`` / ``set_weights``), e.g. to exchange them with a
    coordinator.

    Attributes:
        model: Keras Sequential model (Dense 144-72-18 with ReLU, then a
            single linear output unit), compiled with Adam and MSE loss.
    """

    def __init__(self, dim: int):
        """Build and compile the network.

        Args:
            dim (int): Number of input features.
        """
        self.model = Sequential()
        self.model.add(Dense(144, input_dim=dim, activation="relu"))
        self.model.add(Dense(72, activation="relu"))
        self.model.add(Dense(18, activation="relu"))
        self.model.add(Dense(1, activation="linear"))

        self.model.compile(optimizer="adam", loss="mean_squared_error")

    def train_n_epochs(self, n_epochs: int, x_train: pd.DataFrame,
                       y_train: pd.DataFrame) -> None:
        """Train the model silently for a fixed number of epochs.

        Args:
            n_epochs (int): Number of epochs to be trained.
            x_train (~pd.DataFrame): Features dataset for training.
            y_train (~pd.DataFrame): Labels for training.
        """
        self.model.fit(x_train, y_train, epochs=n_epochs, verbose=0)

    def evaluate_on_test(self, x_test: pd.DataFrame,
                         y_test: pd.DataFrame) -> Tuple[float, float]:
        """Evaluate on the test set.

        Args:
            x_test (dataframe): Feature set for evaluation.
            y_test (dataframe): Dependent variable for evaluation.

        Returns:
            test_loss: Value of the testing loss (MSE, as compiled).
            r_squared: Value of R-squared,
                to be shown as 'accuracy' metric to the Coordinator
        """
        y_pred: np.ndarray = self.model.predict(x_test)
        r_squared: float = r2_score(y_test, y_pred)
        test_loss: float = self.model.evaluate(x_test, y_test)
        return test_loss, r_squared

    def get_shapes(self) -> List[Tuple[int, ...]]:
        """Return the shape of every weight tensor, in model order."""
        return [weight.shape for weight in self.model.get_weights()]

    def get_weights(self) -> np.ndarray:
        """Return all model weights concatenated into one flat 1-D array."""
        return np.concatenate(self.model.get_weights(), axis=None)

    def set_weights(self, weights: np.ndarray) -> None:
        """Load a flat weight vector (as produced by ``get_weights``).

        Args:
            weights: Array holding exactly as many values as the model has
                parameters; it is flattened before being split.

        Raises:
            ValueError: If ``weights`` has the wrong number of values.
        """
        shapes = self.get_shapes()
        # int() because np.prod(()) is the float 1.0 for scalar tensors.
        sizes = [int(np.prod(shape)) for shape in shapes]
        flat = np.ravel(weights)
        expected = sum(sizes)
        if flat.size != expected:
            raise ValueError(
                f"expected {expected} weight values, got {flat.size}")

        # Split points between consecutive tensors; the final cumulative sum
        # is the total length and would only yield a trailing empty chunk.
        boundaries = np.cumsum(sizes)[:-1]
        chunks = np.split(flat, boundaries)
        # Positional shape: the `newshape=` keyword is deprecated in NumPy 2.1.
        tensorflow_weights: List[np.ndarray] = [
            np.reshape(chunk, shape) for chunk, shape in zip(chunks, shapes)
        ]

        # apply the weights to the tensorflow model
        self.model.set_weights(tensorflow_weights)
Exemplo n.º 2
0
# This is a sample Python script.

# Press ⌃R to execute it or replace it with your code.
# Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings.

import tensorflow as tf
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# A single linear unit learning y = 2x - 1 from six samples.
model = Sequential([Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')

# Keras (TF >= 2.7) no longer up-ranks 1-D inputs of shape (batch,) to
# (batch, 1), so the samples are shaped as a column explicitly to match
# the layer's input_shape=[1].
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float).reshape(-1, 1)
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float).reshape(-1, 1)

model.fit(xs, ys, epochs=500)

# One sample x = 10, shaped (1, 1); the ideal answer is 2 * 10 - 1 = 19.
print(model.predict(np.array([[10.0]])))
print("Here is what i have learned: {}".format(model.get_weights()))
    # NOTE(review): these indented lines are the tail of an enclosing
    # statement that is missing from this snippet; it presumably also
    # defines `x_f`, which the plot uses but nothing visible assigns.
    # Keep the unscaled prices for plotting: `prices/100` below rebinds
    # `prices` to a new object, so `y_p` is not rescaled.
    y_p = prices
    # Scale the data down to make the subsequent computation easier.
    features = features/100
    prices = prices/100


# Define the model: a single linear unit, y = w * x + b.
model = Sequential()
model.add(Dense(1, kernel_initializer='random_normal', bias_initializer='zeros'))
# Compile the model. `learning_rate` replaces the deprecated `lr` keyword,
# which recent Keras releases reject. Accuracy is meaningless for a
# continuous target, so only the MSE loss is tracked.
model.compile(optimizer=optimizers.SGD(learning_rate=learning_rate), loss='mse')
# Train the model: one epoch per fit() call so the loss is printed per epoch.
batch_size = 10
for epoch in range(num_epochs):
    history = model.fit(x=features, y=prices, batch_size=batch_size, shuffle=True)
    print('epoch ', epoch+1, ',loss = ', history.history['loss'])
print("w, b=", model.get_weights())
[w, b] = model.get_weights()


# Plot
# cyan dots:    scatter plot of the raw data
# blue squares: the expected line y = 6.7x - 24.42
# red circles:  the fitted line; the bias is multiplied by 100 because both
#               features and prices were divided by 100 before training
# (the `hold` keyword was removed in Matplotlib 3.0; axes always accumulate)
x = np.arange(0, 200, 2)
plt.scatter(x_f, y_p, 3, c='c')
plt.scatter(x, 6.7*x-24.42, 13, c='b', marker='s')
plt.scatter(x, w*x+b*100, 7, c='r', marker='o')
plt.show()

Exemplo n.º 4
0
def _build_mlp(input_dim):
    """Return a compiled MLP binary classifier with a sigmoid output."""
    model = Sequential()
    model.add(Dense(units=15, activation='relu', input_dim=input_dim))
    model.add(Dropout(0.25))
    model.add(Dense(units=50, activation='relu'))
    model.add(Dense(units=20, activation='relu'))
    model.add(Dense(units=50, activation='relu'))
    model.add(Dense(units=40, activation='relu'))
    model.add(Dense(units=40, activation='relu'))
    model.add(Dense(units=20, activation='relu'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model


def trainningNetwork(bagOfWords, y, n_splits=10, epochs=30, random_state=None):
    """Cross-validate an MLP binary classifier with shuffled K-fold splits.

    A fresh network is built and trained for every fold. Its sigmoid outputs
    on the held-out fold are thresholded at 0.5 and scored with
    ``classification_error`` and ``error_measures``.

    Args:
        bagOfWords: Row-indexable feature matrix; ``bagOfWords.shape[1]`` is
            the network's input size.
        y: Labels, indexable by the same row indices as ``bagOfWords``.
        n_splits: Number of folds (default 10, as before).
        epochs: Training epochs per fold (default 30, as before).
        random_state: Seed forwarded to ``KFold``; ``None`` keeps the
            previous unseeded shuffling.

    Returns:
        Tuple ``(model, Errores, Precision, Recall, F1score, seconds)``:
        the network trained on the last fold, per-fold arrays of the
        metrics, and the wall-clock time spent.
    """
    tiempo_i = time.time()

    Errores = np.ones(n_splits)
    Precision = np.zeros(n_splits)
    Recall = np.zeros(n_splits)
    F1score = np.zeros(n_splits)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    for j, (train_index, test_index) in enumerate(kf.split(bagOfWords)):
        X_train, X_test = bagOfWords[train_index], bagOfWords[test_index]
        y_train, y_test = y[train_index], y[test_index]

        model = _build_mlp(bagOfWords.shape[1])
        model.fit(X_train, y_train, epochs=epochs)

        # Threshold the probabilities: <= 0.5 -> class 0, otherwise class 1.
        ypred = model.predict(X_test)
        y_pred = np.where(np.ravel(ypred) <= 0.5, 0, 1)

        Errores[j] = classification_error(y_pred, y_test)
        precision, recall, f1score = error_measures(y_pred, y_test)
        Precision[j] = precision
        Recall[j] = recall
        F1score[j] = f1score

    return model, Errores, Precision, Recall, F1score, time.time() - tiempo_i
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout
import numpy as np

# Model definition
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=200, activation='tanh'),
    Dropout(0.4),
    Dense(units=100, activation='tanh'),
    Dropout(0.4),
    Dense(units=10, activation='softmax')
])

# Save the model parameters.
# NOTE(review): Keras 3 requires a filename ending in `.weights.h5`; this
# path relies on the TF-checkpoint format of tf.keras 2.x.
model.save_weights('my_model/model_weights')

# Get the model parameters: a list holding a kernel and a bias array for
# each Dense layer. The arrays have different shapes, so the list is kept
# as-is; np.array() on such a ragged list raises ValueError on NumPy >= 1.24.
weights = model.get_weights()

# Even indices are kernels (w), odd indices are biases (b); entries i and
# i + 1 belong to Dense layer number i // 2 + 1.
for i, w in enumerate(weights):
    layer_no = i // 2 + 1
    if i % 2 == 0:
        print('{}:w_shape:{}'.format(layer_no, w.shape))
    else:
        print('{}:b_shape:{}'.format(layer_no, w.shape))
class Agent:
    """DQN agent with a policy network, a target network and experience replay.

    Args:
        env: Environment exposing ``observation_space.shape`` and
            ``action_space.n``.
        optimizer: Keras optimizer used to compile both networks.
        batch_size: Maximum number of transitions sampled per ``learn`` call.
    """

    def __init__(self, env, optimizer, batch_size):
        # general info
        # number of factors in the state; e.g: velocity, position, etc
        self.state_size = env.observation_space.shape[0]
        self.action_size = env.action_space.n
        self.optimizer = optimizer
        self.batch_size = batch_size

        # allow large replay exp space; the oldest entries drop out when full
        self.replay_exp = deque(maxlen=1000000)

        self.gamma = 0.99
        self.epsilon = 1.0  # initialize with high exploration, which will decay later

        # Policy and target networks share one architecture.
        self.brain_policy = self._build_network()
        self.brain_target = self._build_network()

        self.update_brain_target()

    def _build_network(self):
        """Return a compiled MLP mapping a state to one Q-value per action."""
        model = Sequential()
        model.add(Dense(128, input_dim=self.state_size, activation="relu"))
        model.add(Dense(128, activation="relu"))
        model.add(Dense(self.action_size, activation="linear"))
        model.compile(loss="mse", optimizer=self.optimizer)
        return model

    def memorize_exp(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.replay_exp.append((state, action, reward, next_state, done))

    def update_brain_target(self):
        """Copy the policy network's weights into the target network."""
        return self.brain_target.set_weights(self.brain_policy.get_weights())

    def choose_action(self, state):
        """Pick an action epsilon-greedily with respect to the policy network."""
        if np.random.uniform(0.0, 1.0) < self.epsilon:  # exploration
            return np.random.choice(self.action_size)
        # `self.state_size`: the bare name `state_size` is undefined here.
        state = np.reshape(state, [1, self.state_size])
        qhat = self.brain_policy.predict(state)  # Q(s, a) for all a, shape (1, n)
        return np.argmax(qhat[0])

    def learn(self):
        """Run one DQN update on a random mini-batch from the replay memory.

        The target for the taken action is ``r + gamma * max_a Q_target(s', a)``
        with the bootstrap term zeroed for terminal transitions. The other
        actions keep the policy network's current predictions. Does nothing
        while the replay memory is empty.
        """
        cur_batch_size = min(len(self.replay_exp), self.batch_size)
        if cur_batch_size == 0:
            return
        mini_batch = random.sample(self.replay_exp, cur_batch_size)

        # batch data
        sample_states = np.empty((cur_batch_size, self.state_size))
        sample_actions = np.empty((cur_batch_size, 1))
        sample_rewards = np.empty((cur_batch_size, 1))
        sample_next_states = np.empty((cur_batch_size, self.state_size))
        sample_dones = np.empty((cur_batch_size, 1))
        for row, exp in enumerate(mini_batch):
            sample_states[row] = exp[0]
            sample_actions[row] = exp[1]
            sample_rewards[row] = exp[2]
            sample_next_states[row] = exp[3]
            sample_dones[row] = exp[4]

        # Terminal transitions have no future value: zero their Q-values,
        # then take the greedy value of each next state.
        qhat_next = self.brain_target.predict(sample_next_states)
        qhat_next = np.max(qhat_next * (1.0 - sample_dones), axis=1)

        q_target = self.brain_policy.predict(sample_states)
        rows = np.arange(cur_batch_size)
        taken = sample_actions[:, 0].astype(int)
        q_target[rows, taken] = sample_rewards[:, 0] + self.gamma * qhat_next

        self.brain_policy.fit(sample_states, q_target, epochs=1, verbose=0)
        """
Exemplo n.º 7
0
class DDQNAgent:
    """Agent with policy/target networks and an ``ExperienceReplay`` buffer
    that may be prioritized.

    NOTE(review): the bootstrap target uses ``max`` over the target network
    (plain DQN); Double DQN would pick the argmax with the policy network.

    Args:
        env: Environment exposing ``observation_space.shape``.
        optimizer: Keras optimizer used to compile both networks.
        gamma: Discount factor.
        batch_size: Maximum number of transitions sampled per ``learn`` call.
    """

    def __init__(self, env, optimizer, gamma, batch_size):
        # number of factors in the state; e.g: velocity, position, etc
        self.state_size = env.observation_space.shape[0]
        # Only the second value returned by quantize(None) (the action count) is used.
        _, self.action_size = quantize(None)
        self.batch_size = batch_size

        self.replay_exp = ExperienceReplay(type=REPLAY_TYPE)

        self.gamma = gamma
        self.epsilon = 1.0  # initialize with high exploration, which will decay later

        # Policy and target networks share one architecture.
        self.brain_policy = self._build_network(optimizer)
        self.brain_target = self._build_network(optimizer)

        self.update_brain_target()

    def _build_network(self, optimizer):
        """Return a compiled MLP mapping a state to one Q-value per action."""
        model = Sequential()
        model.add(Dense(256, input_dim=self.state_size, activation="relu"))
        model.add(Dense(256, activation="relu"))
        model.add(Dense(64, activation="relu"))
        model.add(Dense(self.action_size, activation="linear"))
        model.compile(loss="mse", optimizer=optimizer)
        return model

    def update_brain_target(self):
        """Copy the policy network's weights into the target network."""
        return self.brain_target.set_weights(self.brain_policy.get_weights())

    def choose_action(self, state):
        """Pick an action epsilon-greedily with respect to the policy network."""
        if self._should_do_exploration():
            action = np.random.choice(self.action_size)
        else:
            state = np.reshape(state, [1, self.state_size])
            qhat = self.brain_policy.predict(state)  # Q(s, a) for all a, shape (1, n)
            action = np.argmax(qhat[0])
        return action

    def learn(self, sample=None):
        """Compute TD targets for a batch and update the policy or the priorities.

        Without ``sample``: a mini-batch is drawn from the replay buffer and
        the policy network is fitted on it. For a prioritized buffer the
        absolute TD errors are first handed back via
        ``memorize_exp(mini_batch, {'e': errors, 'is_update': True})``.

        With ``sample`` (one ``(state, action, reward, next_state, done)``
        transition): no fitting happens; the transition is passed to
        ``memorize_exp`` with its TD error and ``'is_update': False``.
        """
        prioritized = self.replay_exp.is_prioritized()
        if sample is None:
            cur_batch_size = min(self.replay_exp.size, self.batch_size)
            if cur_batch_size == 0:
                return
            if prioritized:
                mini_batch = self.replay_exp.replay_exp.sample(cur_batch_size)
            else:
                mini_batch = random.sample(self.replay_exp.replay_exp,
                                           cur_batch_size)
        else:
            cur_batch_size = 1
            mini_batch = [(0, sample)]
        # Prioritized batches and the single-sample wrapper hold
        # (index, transition) pairs; uniform batches hold bare transitions.
        wrapped = prioritized or sample is not None

        sample_states = np.empty((cur_batch_size, self.state_size))
        sample_actions = np.empty((cur_batch_size, 2))
        sample_rewards = np.empty((cur_batch_size, 1))
        sample_next_states = np.empty((cur_batch_size, self.state_size))
        sample_dones = np.empty((cur_batch_size, 1))
        for index, exp in enumerate(mini_batch):
            transition = exp[1] if wrapped else exp
            sample_states[index] = transition[0]
            sample_actions[index] = transition[1]
            sample_rewards[index] = transition[2]
            sample_next_states[index] = transition[3]
            sample_dones[index] = transition[4]

        # Terminal transitions have no future value: zero their Q-values,
        # then take the greedy value of each next state.
        sample_qhat_next = self.brain_target.predict(sample_next_states)
        sample_qhat_next = np.max(sample_qhat_next * (1.0 - sample_dones),
                                  axis=1)

        sample_qhat = self.brain_policy.predict(sample_states)
        errors = np.zeros(cur_batch_size)
        for i in range(cur_batch_size):
            col = ACTIONS_TO_IDX[tuple(sample_actions[i])]
            new_value = sample_rewards[i, 0] + self.gamma * sample_qhat_next[i]
            # Read the current estimate BEFORE overwriting it; reading it
            # afterwards made every TD error (priority) zero.
            errors[i] = abs(sample_qhat[i, col] - new_value)
            sample_qhat[i, col] = new_value

        if sample is not None:
            self.replay_exp.memorize_exp(sample, {
                'e': errors[0],
                'is_update': False
            })
            return

        if prioritized:
            self.replay_exp.memorize_exp(mini_batch, {
                'e': errors,
                'is_update': True
            })
        self.brain_policy.fit(sample_states, sample_qhat, epochs=1, verbose=0)

    def _should_do_exploration(self):
        """Return True with probability ``self.epsilon``."""
        return np.random.uniform(0.0, 1.0) < self.epsilon
Exemplo n.º 8
0

# Save the model architecture
json_config = model.to_json()
print(json_config)


# In[3]:


import json
# Save the architecture as a JSON file. `to_json()` already returns a JSON
# string, so it is written verbatim; `json.dump` would encode it a second
# time (as one quoted string) and the file could not be fed straight to
# `model_from_json`.
with open('model.json', 'w', encoding='utf-8') as m:
    m.write(json_config)


# In[14]:


import numpy as np
# get_weights() returns arrays of different shapes; keep them as a list
# (np.array() on a ragged list raises ValueError on NumPy >= 1.24).
a = model.get_weights()
for i in a:
    print(i.shape)


# In[ ]:


# TODO(review): this cell held only the incomplete statement "for i in",
# a syntax error; complete it once its intent is known.

Exemplo n.º 9
0
class GAN(object):
    """StyleGAN-style GAN: discriminator, style-mapping network and generator,
    plus exponential-moving-average (EMA) copies of the generator parts.

    Depends on module-level names defined elsewhere: ``im_size``, ``cha``,
    ``latent_size``, ``n_layers``, ``d_block`` and ``g_block``.
    ``generator`` reads ``inp_style[0]`` .. ``inp_style[6]``, so ``n_layers``
    must be at least 7.

    Args:
        steps: Stored as ``self.steps``.
        lr: Learning rate of both Adam optimizers (``GMO`` and ``DMO``).
        decay: Accepted for backward compatibility; unused.
    """

    def __init__(self, steps=1, lr=0.00001, decay=0.001):
        # Models
        self.D = None   # discriminator
        self.S = None   # style-mapping network
        self.G = None   # generator

        self.GE = None  # EMA copy of G
        self.SE = None  # EMA copy of S

        self.DM = None  # not assigned by this class
        self.AM = None  # not assigned by this class

        # Evaluation models, built on demand by GenModel() / GenModelA().
        self.GM = None
        self.GMA = None

        # Config
        self.LR = lr
        self.steps = steps
        self.beta = 0.999  # EMA factor used by EMA()

        # Init Models
        self.discriminator()
        self.generator()

        # `learning_rate` is the current keyword; the old `lr` alias is
        # deprecated in tf.keras 2.x and rejected by Keras 3.
        self.GMO = Adam(learning_rate=self.LR, beta_1=0, beta_2=0.999)
        self.DMO = Adam(learning_rate=self.LR, beta_1=0, beta_2=0.999)

        # The averaged copies start as exact clones of the live networks.
        self.GE = clone_model(self.G)
        self.GE.set_weights(self.G.get_weights())

        self.SE = clone_model(self.S)
        self.SE.set_weights(self.S.get_weights())

    def discriminator(self):
        """Build (once) and return the discriminator ``self.D``.

        Input is an ``im_size x im_size x 3`` image; the output is one value
        from a Dense layer without activation.
        """
        if self.D is not None:
            return self.D

        inp = Input(shape=[im_size, im_size, 3])

        # Six blocks with growing channel multipliers (128 -> 64 -> 32 -> 16
        # -> 8 -> 4), then a final block with p=False (stays at 4).
        x = inp
        for mult in (1, 2, 4, 6, 8, 16):
            x = d_block(x, mult * cha)
        x = d_block(x, 32 * cha, p=False)

        x = Flatten()(x)
        x = Dense(1, kernel_initializer="he_uniform")(x)

        self.D = Model(inputs=inp, outputs=x)
        return self.D

    def generator(self):
        """Build (once) the style network ``self.S`` and generator ``self.G``.

        ``G`` takes ``n_layers`` style vectors of size 512 plus one
        ``im_size x im_size x 1`` noise input and sums the per-block outputs
        returned by ``g_block``. Returns ``self.G``.
        """
        if self.G is not None:
            return self.G

        # === Style Mapping ===: four Dense(512) + LeakyReLU(0.2) layers.
        self.S = Sequential()
        self.S.add(Dense(512, input_shape=[latent_size]))
        self.S.add(LeakyReLU(0.2))
        for _ in range(3):
            self.S.add(Dense(512))
            self.S.add(LeakyReLU(0.2))

        # === Generator ===
        inp_style = [Input([512]) for _ in range(n_layers)]
        inp_noise = Input([im_size, im_size, 1])

        # Latent: a (batch, 1) tensor of ones derived from the first style input.
        x = Lambda(lambda x: x[:, :1] * 0 + 1)(inp_style[0])

        x = Dense(
            4 * 4 * 4 * cha, activation="relu", kernel_initializer="random_normal"
        )(x)
        x = Reshape([4, 4, 4 * cha])(x)

        # First block without upsampling (u=False, resolution 4), then six
        # blocks for resolutions 8, 16, 32, 64, 128, 256.
        x, r = g_block(x, inp_style[0], inp_noise, 32 * cha, u=False)
        outs = [r]
        for i, mult in enumerate((16, 8, 6, 4, 2, 1), start=1):
            x, r = g_block(x, inp_style[i], inp_noise, mult * cha)
            outs.append(r)

        x = add(outs)

        # Use values centered around 0, but normalize to [0, 1], providing
        # better initialization.
        x = Lambda(lambda y: y / 2 + 0.5)(x)

        self.G = Model(inputs=inp_style + [inp_noise], outputs=x)
        return self.G

    def GenModel(self):
        """Build ``self.GM``: latents -> S -> G, for evaluation; return it."""
        inp_style = []
        style = []

        for i in range(n_layers):
            inp_style.append(Input([latent_size]))
            style.append(self.S(inp_style[-1]))

        inp_noise = Input([im_size, im_size, 1])

        gf = self.G(style + [inp_noise])

        self.GM = Model(inputs=inp_style + [inp_noise], outputs=gf)

        return self.GM

    def GenModelA(self):
        """Build ``self.GMA``: latents -> SE -> GE (averaged weights); return it."""
        inp_style = []
        style = []

        for i in range(n_layers):
            inp_style.append(Input([latent_size]))
            style.append(self.SE(inp_style[-1]))

        inp_noise = Input([im_size, im_size, 1])

        gf = self.GE(style + [inp_noise])

        self.GMA = Model(inputs=inp_style + [inp_noise], outputs=gf)

        return self.GMA

    def EMA(self):
        """Blend live weights into the averaged copies, layer by layer.

        ``averaged = beta * averaged + (1 - beta) * live`` for G -> GE, then
        S -> SE.
        """
        for live_model, avg_model in ((self.G, self.GE), (self.S, self.SE)):
            for live_layer, avg_layer in zip(live_model.layers, avg_model.layers):
                up_weight = live_layer.get_weights()
                old_weight = avg_layer.get_weights()
                avg_layer.set_weights([
                    old * self.beta + (1 - self.beta) * up
                    for old, up in zip(old_weight, up_weight)
                ])

    def MAinit(self):
        """Reset the averaged copies to the current live weights."""
        self.GE.set_weights(self.G.get_weights())
        self.SE.set_weights(self.S.get_weights())
Exemplo n.º 10
0
class MultiTaskModel(Sequential):
    def __init__(self,
                 image_shape,
                 num_labels,
                 num_inputs=4,
                 trainableVariables=None,
                 attention=None,
                 two_stage=False,
                 pix2pix=False):
        """Create the per-input SegNets, optional attention gates and heads.

        Args:
            image_shape: Shape of one image used for reconstruction; its
                first three entries are multiplied to size the
                reconstruction head's output.
            num_labels: Number of classes of the softmax prediction head.
            num_inputs: Number of input channels/streams (edge, texture,
                ...); one SegNet is built per stream.
            trainableVariables: Initial ``self.trainableVariables`` list (not
                to be confused with Keras' read-only ``trainable_variables``);
                a new empty list when None.
            attention: None, ``"self"`` (a SelfAttention gate pair per
                stream) or ``"multi"`` (a CompleteAttention gate pair over all
                streams).
            two_stage: Stored as ``self.two_stage``.
            pix2pix: When true, also build ``self.discriminator``.
        """
        # TODO - kwargs support for segnet initializations
        super().__init__()
        self.num_inputs = num_inputs
        self.image_shape = image_shape
        self.segnets = []
        self.attention = attention
        self.pix2pix = pix2pix
        self.two_stage = two_stage

        if self.attention == "self":
            self.attention_gates_rec = []
            self.attention_gates_pred = []
        self.trainableVariables = (
            [] if trainableVariables is None else trainableVariables)

        def gate_conv(filters, **extra):
            # 3x3 "same"-padded convolution with Glorot-normal init.
            return Conv2D(filters=filters,
                          kernel_size=3,
                          padding="same",
                          kernel_initializer='glorot_normal',
                          **extra)

        def new_self_attention():
            return SelfAttention(
                [gate_conv(128), gate_conv(512, activation="sigmoid")])

        def new_complete_attention():
            return CompleteAttention(
                layers=[
                    gate_conv(128, activation="relu"),
                    gate_conv(128),
                    gate_conv(512 * self.num_inputs, activation="sigmoid"),
                ],
                num_streams=self.num_inputs)

        # TODO make better attention layers.
        for _ in range(num_inputs):
            self.segnets.append(SegNet())
            if self.attention == "self":
                self.attention_gates_rec.append(new_self_attention())
                self.attention_gates_pred.append(new_self_attention())
            elif self.attention == "multi":
                # NOTE(review): rebuilt on every iteration; only the pair
                # from the last iteration is kept.
                self.attention_gates_rec = new_complete_attention()
                self.attention_gates_pred = new_complete_attention()

        print("Image_Shape", image_shape)
        # TODO fix the reconstruction net to have some conv layers
        self.reconstruct_image = Sequential([
            Flatten(),
            BatchNormalization(axis=-1),
            Dense(1024),
            Dense(image_shape[0] * image_shape[1] * image_shape[2],
                  activation='sigmoid'),
        ])
        # Classification head with a softmax output.
        self.predict_label = Sequential([
            Conv2D(filters=128, kernel_size=2, padding="same", activation="relu"),
            Conv2D(filters=128, kernel_size=2, padding="valid"),
            Flatten(),
            Dense(num_labels, activation='softmax'),
        ])

        if self.pix2pix:
            self.discriminator = Sequential()
            disc_stack = [
                ReshapeAndConcat(),
                Conv2D(128, 3, padding="valid"),
                MaxPool2D(pool_size=(2, 2)),
                LeakyReLU(),
                BatchNormalization(axis=-1),
                Conv2D(128, 3, padding="valid"),
                MaxPool2D(pool_size=(2, 2)),
                LeakyReLU(),
                Flatten(),
                BatchNormalization(axis=-1),
                Dense(32, activation='relu'),
                BatchNormalization(axis=-1),
                Dense(1, activation='sigmoid', kernel_initializer="glorot_normal"),
            ]
            for disc_layer in disc_stack:
                self.discriminator.add(disc_layer)

    def examine_disc(self, X):
        """Feed X through every discriminator layer except the last; return it."""
        hidden_layers = self.discriminator.layers[:-1]
        for layer in hidden_layers:
            X = layer.call(X)
        return X

    def regularizer_naive(self, means, meansum, indx):
        """Return ``|means[indx] / meansum|`` via ``tf.math.abs``."""
        share = means[indx] / meansum
        return tf.math.abs(share)

    def regularizer_ratio(self, meansr, meansp, meansumr, meansump):
        """Return the summed absolute difference of normalized means.

        Computes ``sum_i |meansr[i] / meansumr - meansp[i] / meansump|`` for
        ``i`` in ``range(self.num_inputs)``, using ``tf.math.abs``.
        """
        # `total` rather than `sum`, which would shadow the builtin.
        total = tf.math.abs(meansr[0] / meansumr - meansp[0] / meansump)
        for i in range(1, self.num_inputs):
            total += tf.math.abs(meansr[i] / meansumr - meansp[i] / meansump)
        return total

    def setTrainableVariables(self, trainableVariables=None):
        """Set, or extend, the list of variables optimized by training.

        With an explicit list, it simply replaces ``self.trainableVariables``.
        Otherwise the trainable variables of every SegNet, the attention gates
        (if any), the reconstruction head and the prediction head are
        appended to the existing list. With pix2pix, the discriminator's
        variables are collected separately into ``self.disc_train_vars``.
        """
        if trainableVariables is not None:
            self.trainableVariables = trainableVariables
            return

        for idx in range(self.num_inputs):
            print("On segnet", idx)
            self.trainableVariables += self.segnets[idx].trainable_variables

        if self.attention == "self":
            for idx in range(self.num_inputs):
                for gates in (self.attention_gates_rec, self.attention_gates_pred):
                    self.trainableVariables += gates[idx].trainable_variables
        elif self.attention == "multi":
            for gate in (self.attention_gates_rec, self.attention_gates_pred):
                self.trainableVariables += gate.trainable_variables

        for head in (self.reconstruct_image, self.predict_label):
            self.trainableVariables += head.trainable_variables

        if self.pix2pix:
            self.disc_train_vars = [
                var
                for disc_layer in self.discriminator.layers
                for var in disc_layer.trainable_variables
            ]

    # @tf.function
    def build(self, X):
        """Run one pass over ``X`` so every sub-layer creates its variables.

        Runs each SegNet, the optional attention gates, both heads and (with
        pix2pix) the discriminator. The outputs are gathered into a local
        list that is not returned.

        Args:
            X: List of ``num_inputs`` tensors; ``X[0].shape`` is unpacked as
                ``(batch, h, w, c)`` and used to reshape the reconstruction.
        """
        batch, h, w, c = X[0].shape
        assert len(X) == self.num_inputs
        result = []
        encoded_reps, rec = self.segnets[0].call(X[0])
        if self.attention == "self":
            encoded_reps_rec = self.attention_gates_rec[0].call(encoded_reps)
            encoded_reps_pred = self.attention_gates_pred[0].call(encoded_reps)
        result.append(rec)
        for i in range(self.num_inputs - 1):
            enc, rec = self.segnets[i + 1].call(X[i + 1])
            if self.attention == "self":
                # NOTE(review): the gates receive `encoded_reps` (stream 0's
                # encoding) rather than this stream's `enc`, which is then
                # unused; looks unintended - confirm against `call`.
                encoded_attended_rec = self.attention_gates_rec[i + 1].call(
                    encoded_reps)
                encoded_attended_pred = self.attention_gates_pred[i + 1].call(
                    encoded_reps)
                print("-----------\n", encoded_reps_rec.shape,
                      encoded_attended_rec.shape)
                encoded_reps_rec = tf.concat(
                    [encoded_reps_rec, encoded_attended_rec], axis=-1)
                encoded_reps_pred = tf.concat(
                    [encoded_reps_pred, encoded_attended_pred], axis=-1)
            else:
                encoded_reps = tf.concat([encoded_reps, enc], axis=-1)
            result.append(rec)

        if self.attention == "multi":
            encoded_reps_rec, _, _ = self.attention_gates_rec(encoded_reps)
            encoded_reps_pred, _, _ = self.attention_gates_pred(encoded_reps)

        # With attention, the heads read the gated representations; without
        # it, both read the concatenated encodings.
        if self.attention is not None:
            rec_input, pred_input = encoded_reps_rec, encoded_reps_pred
        else:
            rec_input = pred_input = encoded_reps
        result.append(
            tf.reshape(self.reconstruct_image(rec_input), (batch, h, w, c)))
        result.append(self.predict_label(pred_input))  # final labels

        if self.pix2pix:
            result.append(rec_input)  # needed for pix2pix
            self.discriminator.call((result[-1], result[self.num_inputs]))
            # The context manager closes the log even if the write fails.
            with open("log_pix2pix.txt", "w") as log:
                log.write("\n")

    @tf.function
    def call(self, X, classification=False):
        """Run the forward pass over all inputs.

        Args:
            X: List of ``self.num_inputs`` image batches. ``X[0].shape`` is
                unpacked as ``(batch, h, w, c)`` and the fused image is
                reshaped to that shape.
            classification: Not used by this method. It is kept so existing
                callers keep working.

        Returns:
            A list laid out as:

            * ``[0, num_inputs)``: per-input reconstructions from each segnet.
            * ``num_inputs``: fused image from ``reconstruct_image``.
            * ``num_inputs + 1``: output of ``predict_label``.
            * ``num_inputs + 2``: the representation fed to
              ``reconstruct_image``. Only present when ``self.pix2pix`` is set.
            * last four: ``meansr, meansumr, meansp, meansump`` returned by the
              gates. Only present when ``self.attention == "multi"``.
        """
        #TODO refactor to make everything a tensor
        batch, h, w, c = X[0].shape
        assert len(X) == self.num_inputs
        result = []
        encoded_reps, rec = self.segnets[0].call(X[0])
        if self.attention == "self":
            encoded_reps_rec = self.attention_gates_rec[0].call(encoded_reps)
            encoded_reps_pred = self.attention_gates_pred[0].call(encoded_reps)
        result.append(rec)
        for i in range(1, self.num_inputs):
            enc, rec = self.segnets[i].call(X[i])
            if self.attention == "self":
                # Gate i attends over the encoding of input i. Previously
                # every gate was fed input 0's encoding and ``enc`` was
                # discarded.
                encoded_reps_rec = tf.concat(
                    [encoded_reps_rec, self.attention_gates_rec[i].call(enc)],
                    axis=-1)
                encoded_reps_pred = tf.concat(
                    [encoded_reps_pred,
                     self.attention_gates_pred[i].call(enc)],
                    axis=-1)
            else:
                encoded_reps = tf.concat([encoded_reps, enc], axis=-1)
            result.append(rec)

        if self.attention == "multi":
            encoded_reps_rec, meansr, meansumr = self.attention_gates_rec(
                encoded_reps)
            encoded_reps_pred, meansp, meansump = self.attention_gates_pred(
                encoded_reps)
        elif self.attention is None:
            # Without attention both heads consume the concatenated encoding.
            encoded_reps_rec = encoded_reps_pred = encoded_reps

        result.append(
            tf.reshape(self.reconstruct_image(encoded_reps_rec),
                       (batch, h, w, c)))
        result.append(self.predict_label(encoded_reps_pred))  # final labels
        if self.pix2pix:
            result.append(encoded_reps_rec)  # needed for pix2pix
        if self.attention == "multi":
            result.extend([meansr, meansumr, meansp, meansump])

        return result

    def loss_reconstruction(self, X, Y, beta=0.0):
        """Blend the squared and absolute element-wise error of ``X`` and ``Y``.

        Both sums run over every element and are divided by
        ``X.shape[1] * X.shape[2] * X.shape[3]``. Assuming ``X`` is
        ``(batch, h, w, c)`` (TODO confirm), the result is the batch total of
        per-sample means.

        Args:
            X: First tensor; must have at least four dimensions.
            Y: Second tensor, broadcast-compatible with ``X``.
            beta: Weight of the L1 term. ``1 - beta`` weights the L2 term.

        Returns:
            Scalar tensor ``(1 - beta) * L2 + beta * L1``.
        """
        residual = X - Y
        per_sample_size = X.shape[1] * X.shape[2] * X.shape[3]
        squared_total = tf.math.reduce_sum(residual**2)
        absolute_total = tf.math.reduce_sum(tf.math.abs(residual))
        return ((1 - beta) * squared_total / per_sample_size +
                beta * absolute_total / per_sample_size)

    def loss_classification(self, X, labels):
        """Return the mean binary cross-entropy of predictions ``X`` against ``labels``.

        ``1e-10`` is added inside both logarithms so that predictions of
        exactly 0 or 1 never evaluate ``log(0)``.
        """
        eps = 1e-10
        log_p = tf.math.log(X + eps)
        log_not_p = tf.math.log(1 - X + eps)
        return -1 * tf.reduce_mean(labels * log_p + (1 - labels) * log_not_p)

    def generator_loss(self,
                       disc_generated_output,
                       gen_output,
                       target,
                       LAMBDA=0.1):
        """Return the pix2pix generator objective: adversarial term plus weighted L1.

        Args:
            disc_generated_output: Discriminator output for the generated image.
                Scored against an all-ones target with ``loss_classification``.
            gen_output: Generated image.
            target: Ground-truth image.
            LAMBDA: Weight of the mean absolute error term.
        """
        adversarial = self.loss_classification(
            disc_generated_output, tf.ones_like(disc_generated_output))
        mean_abs_error = tf.reduce_mean(tf.abs(target - gen_output))
        return adversarial + (LAMBDA * mean_abs_error)

    def discriminator_loss(self, disc_real_output, disc_generated_output):
        """Return the pix2pix discriminator objective.

        Real outputs are scored against ones and generated outputs against
        zeros, both with ``loss_classification``. The two terms are summed.
        """
        on_real = self.loss_classification(disc_real_output,
                                           tf.ones_like(disc_real_output))
        on_generated = self.loss_classification(
            disc_generated_output, tf.zeros_like(disc_generated_output))
        return on_real + on_generated

    #TODO make this tf.function
    def train_on_batch(self,
                       X,
                       Y_image,
                       Y_labels,
                       optimizer,
                       classification=False):
        """Run one optimization step on a batch.

        The objective depends on ``self.two_stage``:

        * Single stage: per-input reconstruction, plus a fused-image term
          (pix2pix generator loss or plain reconstruction), plus
          classification. ``regularizer_ratio`` is added for "multi"
          attention. All ``self.trainableVariables`` are updated.
        * Two stage, ``classification`` False: the reconstruction terms
          only. All ``self.trainableVariables`` are updated.
        * Two stage, ``classification`` True: classification (plus the
          regularizer) only. Only the prediction head and its attention gates
          are updated.

        Whenever a pix2pix discriminator loss is computed, the discriminator
        variables receive a separate update from the same ``optimizer``.

        Args:
            X: List of ``self.num_inputs`` input batches, passed to ``self.call``.
            Y_image: Target for the fused image.
            Y_labels: Target labels for ``loss_classification``.
            optimizer: Object whose ``apply_gradients`` performs the updates.
            classification: Selects the classification stage when
                ``self.two_stage`` is set.

        Returns:
            ``(loss, losses)``: the total loss that was differentiated and the
            list of individual loss terms. The order is unchanged from the
            previous implementation.
        """
        # Layout of ``result`` produced by ``call``:
        #   result[:num_inputs]    per-input reconstructions
        #   result[num_inputs]     fused image
        #   result[num_inputs + 1] predicted labels
        #   result[num_inputs + 2] discriminator conditioning (pix2pix only)
        #   result[-4:]            attention statistics ("multi" only)
        # The conditioning input is taken by index: ``result[-1]`` is an
        # attention statistic when "multi" attention is enabled.
        image_idx = self.num_inputs
        label_idx = self.num_inputs + 1
        cond_idx = self.num_inputs + 2
        classification_stage = self.two_stage and classification
        loss_disc = None  # stays None unless a discriminator loss is computed

        with tf.GradientTape(persistent=True) as tape:
            result = self.call(X)
            losses = []
            if classification_stage:
                loss = self.loss_classification(result[label_idx], Y_labels)
                losses.append(loss)
                if self.attention == 'multi':
                    loss += self.regularizer_ratio(
                        result[-2], result[-4], result[-1],
                        result[-3])  #TODO Have this tunable
            else:
                loss = 0
                for i in range(self.num_inputs):
                    rec_loss = self.loss_reconstruction(X[i], result[i])
                    loss += rec_loss
                    losses.append(rec_loss)

                if self.pix2pix:
                    cond = result[cond_idx]
                    disc_real_output = self.discriminator.call(
                        (cond, Y_image))
                    disc_generated_output = self.discriminator.call(
                        (cond, result[image_idx]))
                    gen_loss = self.generator_loss(disc_generated_output,
                                                   result[image_idx],
                                                   Y_image)
                    loss += gen_loss
                    # Computed in the two-stage reconstruction phase as well.
                    # Previously the discriminator update there received a
                    # constant 0 and therefore had no gradients.
                    loss_disc = self.discriminator_loss(
                        disc_real_output, disc_generated_output)
                    if not self.two_stage:
                        losses.append(loss_disc)
                        losses.append(gen_loss)
                else:
                    fused_loss = self.loss_reconstruction(
                        result[image_idx], Y_image)
                    loss += fused_loss
                    losses.append(fused_loss)

                if not self.two_stage:
                    cls_loss = self.loss_classification(
                        result[label_idx], Y_labels)
                    loss += cls_loss
                    losses.append(cls_loss)
                    if self.attention == 'multi':
                        loss += self.regularizer_ratio(
                            result[-2], result[-4], result[-1],
                            result[-3])  #TODO Have this tunable

        if classification_stage:
            # Copy so that appending gate variables never mutates a list owned
            # by the prediction head.
            train_vars = list(self.predict_label.trainable_variables)
            if self.attention == "self":
                for att in self.attention_gates_pred:
                    train_vars += att.trainable_variables
            elif self.attention == "multi":
                train_vars += self.attention_gates_pred.trainable_variables
        else:
            train_vars = self.trainableVariables

        grads = tape.gradient(loss, train_vars)
        optimizer.apply_gradients(zip(grads, train_vars))

        if loss_disc is not None:
            grad_disc = tape.gradient(loss_disc, self.disc_train_vars)
            optimizer.apply_gradients(zip(grad_disc, self.disc_train_vars))

        del tape
        return loss, losses

    def validate_batch(self, X, Y_image, Y_labels):
        """Evaluate one batch without updating any weights.

        Args:
            X: List of ``self.num_inputs`` input batches, passed to ``self.call``.
            Y_image: Target for the fused image reconstruction.
            Y_labels: Label batch, compared by argmax over axis 1.

        Returns:
            ``(num_correct, losses)``.

            * ``num_correct`` counts rows where the argmax of the predicted
              labels equals the argmax of ``Y_labels``.
            * ``losses`` is ``[per-input reconstruction..., fused
              reconstruction, classification]``.
        """
        result = self.call(X)
        # Fixed positions in ``call``'s output. Indexing from the end is wrong
        # when "multi" attention appends four statistics.
        image_idx = self.num_inputs
        label_idx = self.num_inputs + 1

        losses = [
            self.loss_reconstruction(X[i], result[i])
            for i in range(self.num_inputs)
        ]
        losses.append(self.loss_reconstruction(result[image_idx], Y_image))
        losses.append(self.loss_classification(result[label_idx], Y_labels))

        if self.pix2pix:
            cond = result[self.num_inputs + 2]  # discriminator conditioning
            fused = result[image_idx]
            with open("log_pix2pix.txt", "a") as log:
                log.write("Rec {}\n".format(self.examine_disc((cond, fused))))
                log.write("Actual {}\n".format(
                    self.examine_disc((cond, Y_image))))
                log.write("Rec {}\n".format(self.discriminator((cond, fused))))
                log.write("Actual {}\n".format(
                    self.discriminator((cond, Y_image))))
                log.write("weights {}\n".format(
                    self.discriminator.layers[-1].trainable_variables))

        predicted = tf.math.argmax(result[label_idx], axis=1).numpy()
        return (predicted == np.argmax(Y_labels, axis=1)).sum(), losses

    def getAttentionMap(self, X):
        """Return attention outputs for ``X`` and the labels predicted from them.

        Args:
            X: List of ``self.num_inputs`` input batches.

        Returns:
            ``(attention_maps_rec, attention_maps_pred, predictions)`` as numpy
            arrays.

            * For "self" attention, the first two hold, per input, the
              ``get_attention_map`` output of that input's gate.
            * Otherwise they are the first output of ``attention_gates_rec`` /
              ``attention_gates_pred`` applied to the concatenated encodings.
            * ``predictions`` is the argmax over axis 1 of ``predict_label``
              applied to the prediction maps concatenated on the last axis.
        """
        if self.attention == "self":
            assert len(X) == self.num_inputs
            # Encode every input once. Gate i attends over input i's encoding;
            # previously all gates were fed input 0's encoding.
            encodings = [
                self.segnets[i].call(X[i])[0] for i in range(self.num_inputs)
            ]
            attention_maps_rec = [
                self.attention_gates_rec[i].get_attention_map(enc).numpy()
                for i, enc in enumerate(encodings)
            ]
            attention_maps_pred = [
                self.attention_gates_pred[i].get_attention_map(enc).numpy()
                for i, enc in enumerate(encodings)
            ]
        else:
            encoded_reps, _ = self.segnets[0].call(X[0])
            for i in range(self.num_inputs - 1):
                enc, _ = self.segnets[i + 1].call(X[i + 1])
                encoded_reps = tf.concat([encoded_reps, enc], axis=-1)

            attention_maps_rec = self.attention_gates_rec(
                encoded_reps)[0].numpy()
            attention_maps_pred = self.attention_gates_pred(
                encoded_reps)[0].numpy()
        result = tf.math.argmax(self.predict_label(
            tf.concat(attention_maps_pred, axis=-1)),
                                axis=1)

        return np.array(attention_maps_rec), np.array(
            attention_maps_pred), result.numpy()

    def save(self, modelDir):
        """Write every sub-model's weights under ``modelDir``.

        Each segnet saves itself to ``{modelDir}/Segnet-{i}``. The attention
        gates (if any), the reconstruction head and the prediction head are
        pickled as ``get_weights()`` lists under fixed names: ``Attention-*``,
        ``Reconstruction-Model`` and ``Prediction-Model``.

        Args:
            modelDir: Directory to write into.
        """

        def dump(weights, path):
            # The context manager guarantees each pickle is flushed and closed.
            with open(path, "wb") as handle:
                pickle.dump(weights, handle)

        for i, segnet in enumerate(self.segnets):
            segnet.save("{}/Segnet-{}".format(modelDir, i))
        if self.attention == 'self':
            for i in range(len(self.segnets)):
                dump(self.attention_gates_pred[i].get_weights(),
                     "{}/Attention-Pred-{}".format(modelDir, i))
                dump(self.attention_gates_rec[i].get_weights(),
                     "{}/Attention-Rec-{}".format(modelDir, i))
        elif self.attention == 'multi':
            # These names carry no index. The stray loop variable formerly
            # passed to format() raised NameError when there were no segnets.
            dump(self.attention_gates_pred.get_weights(),
                 "{}/Attention-Pred".format(modelDir))
            dump(self.attention_gates_rec.get_weights(),
                 "{}/Attention-Rec".format(modelDir))

        dump(self.reconstruct_image.get_weights(),
             "{}/Reconstruction-Model".format(modelDir))
        dump(self.predict_label.get_weights(),
             "{}/Prediction-Model".format(modelDir))

    def load_model(self, modelDir, attention, two_stage, pix2pix):
        for i in range(len(self.segnets)):
            self.segnets[i].load_model("{}/Segnet-{}".format(modelDir, i))
        rec_train_vars = pickle.load(
            open("{}/Reconstruction-Model".format(modelDir), "rb"))
        pred_train_vars = pickle.load(
            open("{}/Prediction-Model".format(modelDir), "rb"))
        # for l in self.reconstruct_image.layers:
        #   weights = rec_train_vars
        self.reconstruct_image.set_weights(rec_train_vars)
        # for l in self.predict_label.layers:
        #   weights = pred_train_vars
        self.predict_label.set_weights(pred_train_vars)
        self.TrainableVarsSet = False
        self.pix2pix = pix2pix
        self.attention = attention
        self.two_stage = two_stage

        if self.attention == "multi":
            pred_gates = pickle.load(
                open("{}/Attention-Pred".format(modelDir), "rb"))
            rec_gates = pickle.load(
                open("{}/Attention-Rec".format(modelDir), "rb"))
            self.attention_gates_pred.set_weights(pred_gates)
            self.attention_gates_rec.set_weights(rec_gates)
        elif self.attention == 'self':
            raise NotImplementedError