    def __init__(self,
                 dim,
                 batch_norm,
                 dropout,
                 rec_dropout,
                 header,
                 mode,
                 partition,
                 ihm_pos,
                 target_repl=False,
                 depth=1,
                 input_dim=76,
                 size_coef=4,
                 **kwargs):
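        """Multitask channel-wise LSTM.

        Runs a separate LSTM stack over each channel of `header`,
        concatenates the results, and attaches four output heads:
        in-hospital mortality (ihm), decompensation (decomp),
        length of stay (los) and phenotyping (pheno).
        """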

        print "==> not used params in network class:", kwargs.keys()

        self.dim = dim
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.rec_dropout = rec_dropout
        self.depth = depth
        self.size_coef = size_coef

        # Parse channels
        channel_names = set()
        for ch in header:
            if ch.find("mask->") != -1:
                continue
            pos = ch.find("->")
            if pos != -1:
                channel_names.add(ch[:pos])
            else:
                channel_names.add(ch)
        channel_names = sorted(channel_names)
        print("==> found {} channels: {}".format(len(channel_names),
                                                 channel_names))

        channels = []  # each channel is a list of columns
        for ch in channel_names:
            indices = range(len(header))
            indices = list(filter(lambda i: header[i].find(ch) != -1, indices))
            channels.append(indices)
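
        # Worked example (hypothetical header): with
        #     header = ['Hours', 'Glucose->0', 'Glucose->1', 'mask->Glucose']
        # the loop above gives channel_names = ['Glucose', 'Hours'] and
        # channels = [[1, 2, 3], [0]]; note that the substring match also
        # picks up each channel's mask column ('mask->Glucose').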

        # Input layers and masking
        X = Input(shape=(None, input_dim), name='X')
        mX = Masking()(X)  # mask padded timesteps (all-zero feature rows)

        # Masks
        ihm_M = Input(shape=(1, ), name='ihm_M')
        decomp_M = Input(shape=(None, ), name='decomp_M')
        los_M = Input(shape=(None, ), name='los_M')

        inputs = [X, ihm_M, decomp_M, los_M]

        # Preprocess each channel
        cX = []
        for ch in channels:
            cX.append(Slice(ch)(mX))
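        # `Slice` is a custom layer from the repo's keras_utils; consistent
        # with its use here, it selects the given feature columns on the
        # last axis, roughly (sketch under that assumption):
        #     class Slice(Layer):
        #         def __init__(self, indices, **kwargs): ...
        #         def call(self, x):
        #             return tf.gather(x, self.indices, axis=2)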
        pX = []  # LSTM processed version of cX
        for x in cX:
            p = x
            for i in range(depth):
                p = LSTM(units=dim,
                         activation='tanh',
                         return_sequences=True,
                         dropout=dropout,
                         recurrent_dropout=rec_dropout)(p)
            pX.append(p)

        # Concatenate processed channels
        Z = Concatenate(axis=2)(pX)

        # Main part of the network
        for i in range(depth):
            Z = LSTM(units=int(size_coef * dim),
                     activation='tanh',
                     return_sequences=True,
                     dropout=dropout,
                     recurrent_dropout=rec_dropout)(Z)
        L = Z

        if dropout > 0:
            L = Dropout(dropout)(L)

        # Output modules
        outputs = []

        ## ihm output

        # NOTE: masking for the ihm prediction works as follows:
        #   if ihm_M == 1, a normal error term is computed;
        #   if ihm_M == 0, the prediction is forced to 0 and, since the
        #   label is also 0, the error term vanishes (BCE(0, 0) = 0).
        if target_repl:
            ihm_seq = TimeDistributed(Dense(1, activation='sigmoid'),
                                      name='ihm_seq')(L)
            ihm_y = GetTimestep(ihm_pos)(ihm_seq)
            ihm_y = Multiply(name='ihm_single')([ihm_y, ihm_M])
            outputs += [ihm_y, ihm_seq]
        else:
            ihm_seq = TimeDistributed(Dense(1, activation='sigmoid'))(L)
            ihm_y = GetTimestep(ihm_pos)(ihm_seq)
            ihm_y = Multiply(name='ihm')([ihm_y, ihm_M])
            outputs += [ihm_y]
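
        # `GetTimestep` (from the repo's keras_utils) picks one timestep out
        # of a sequence; a minimal sketch consistent with its use above:
        #     class GetTimestep(Layer):
        #         def __init__(self, pos=0, **kwargs): ...
        #         def call(self, x):
        #             return x[:, self.pos, :]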

        ## decomp output
        decomp_y = TimeDistributed(Dense(1, activation='sigmoid'))(L)
        decomp_y = ExtendMask(name='decomp',
                              add_epsilon=True)([decomp_y, decomp_M])
        outputs += [decomp_y]

        ## los output
        if partition == 'none':
            # regression: predict the remaining length of stay directly
            los_y = TimeDistributed(Dense(1, activation='relu'))(L)
        else:
            # classification over 10 length-of-stay buckets
            los_y = TimeDistributed(Dense(10, activation='softmax'))(L)
        los_y = ExtendMask(name='los', add_epsilon=True)([los_y, los_M])
        outputs += [los_y]
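
        # `ExtendMask` (also from keras_utils) returns its first input while
        # attaching the second as a Keras mask, so losses at masked
        # timesteps are dropped; `add_epsilon` presumably keeps the mask
        # from being exactly zero everywhere. Sketch of the key idea:
        #     class ExtendMask(Layer):
        #         def call(self, inputs):
        #             return inputs[0]
        #         def compute_mask(self, inputs, mask=None):
        #             return inputs[1]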

        ## pheno output (25 phenotype labels)
        if target_repl:
            pheno_seq = TimeDistributed(Dense(25, activation='sigmoid'),
                                        name='pheno_seq')(L)
            pheno_y = LastTimestep(name='pheno_single')(pheno_seq)
            outputs += [pheno_y, pheno_seq]
        else:
            pheno_seq = TimeDistributed(Dense(25, activation='sigmoid'))(L)
            pheno_y = LastTimestep(name='pheno')(pheno_seq)
            outputs += [pheno_y]

        super(Network, self).__init__(inputs=inputs, outputs=outputs)
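
        # Hypothetical usage sketch (argument values are illustrative and
        # `header` must match the multitask reader's columns):
        #     model = Network(dim=256, batch_norm=False, dropout=0.3,
        #                     rec_dropout=0.3, header=header, mode='train',
        #                     partition='custom', ihm_pos=47)
        #     model.compile(optimizer='adam',
        #                   loss={'ihm': 'binary_crossentropy', ...})


    # NOTE: the constructor below belongs to a separate, single-task
    # Network class; the two variants are shown back-to-back here.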
    def __init__(self,
                 dim,
                 batch_norm,
                 dropout,
                 rec_dropout,
                 header,
                 task,
                 target_repl=False,
                 deep_supervision=False,
                 num_classes=1,
                 depth=1,
                 input_dim=76,
                 size_coef=4,
                 **kwargs):
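        """Single-task channel-wise LSTM.

        Same channel-wise front end as the multitask variant above, with a
        single output head whose activation depends on `task`; supports
        target replication and deep supervision.
        """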

        self.dim = dim
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.rec_dropout = rec_dropout
        self.depth = depth
        self.size_coef = size_coef

        if task in ['decomp', 'ihm', 'ph']:
            final_activation = 'sigmoid'
        elif task in ['los']:
            if num_classes == 1:
                final_activation = 'relu'
            else:
                final_activation = 'softmax'
        else:
            raise ValueError("Wrong value for task: {}".format(task))

        print("==> not used params in network class:", kwargs.keys())

        # Parse channels
        channel_names = set()
        for ch in header:
            if ch.find("mask->") != -1:
                continue
            pos = ch.find("->")
            if pos != -1:
                channel_names.add(ch[:pos])
            else:
                channel_names.add(ch)
        channel_names = sorted(channel_names)
        print("==> found {} channels: {}".format(len(channel_names),
                                                 channel_names))

        channels = []  # each channel is a list of columns
        for ch in channel_names:
            indices = range(len(header))
            indices = list(filter(lambda i: header[i].find(ch) != -1, indices))
            channels.append(indices)

        # Input layers and masking
        X = Input(shape=(None, input_dim), name='X')
        inputs = [X]
        mX = Masking()(X)

        if deep_supervision:
            M = Input(shape=(None, ), name='M')
            inputs.append(M)

        # Configurations: deep supervision predicts at every timestep, and a
        # bidirectional layer would let those predictions see future inputs,
        # so it is disabled in that case.
        is_bidirectional = not deep_supervision

        # Preprocess each channel
        cX = []
        for ch in channels:
            cX.append(Slice(ch)(mX))
        pX = []  # LSTM processed version of cX
        for x in cX:
            p = x
            for i in range(depth):
                # Bidirectional concatenates the forward and backward
                # outputs, so halve the units to keep the output size at dim.
                num_units = dim
                if is_bidirectional:
                    num_units = num_units // 2

                lstm = LSTM(units=num_units,
                            activation='tanh',
                            return_sequences=True,
                            dropout=dropout,
                            recurrent_dropout=rec_dropout)

                if is_bidirectional:
                    p = Bidirectional(lstm)(p)
                else:
                    p = lstm(p)
            pX.append(p)

        # Concatenate processed channels
        Z = Concatenate(axis=2)(pX)

        # Main part of the network
        for i in range(depth - 1):
            # halved under Bidirectional to keep size_coef * dim outputs
            num_units = int(size_coef * dim)
            if is_bidirectional:
                num_units = num_units // 2

            lstm = LSTM(units=num_units,
                        activation='tanh',
                        return_sequences=True,
                        dropout=dropout,
                        recurrent_dropout=rec_dropout)

            if is_bidirectional:
                Z = Bidirectional(lstm)(Z)
            else:
                Z = lstm(Z)

        # Output module of the network
        return_sequences = (target_repl or deep_supervision)
        L = LSTM(units=int(size_coef * dim),
                 activation='tanh',
                 return_sequences=return_sequences,
                 dropout=dropout,
                 recurrent_dropout=rec_dropout)(Z)

        if dropout > 0:
            L = Dropout(dropout)(L)

        if target_repl:
            y = TimeDistributed(Dense(num_classes,
                                      activation=final_activation),
                                name='seq')(L)
            y_last = LastTimestep(name='single')(y)
            outputs = [y_last, y]
        elif deep_supervision:
            y = TimeDistributed(Dense(num_classes,
                                      activation=final_activation))(L)
            # attach M as y's mask so masked timesteps drop out of the loss
            y = ExtendMask()([y, M])
            outputs = [y]
        else:
            y = Dense(num_classes, activation=final_activation)(L)
            outputs = [y]

        super(Network, self).__init__(inputs=inputs, outputs=outputs)
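
        # Hypothetical usage sketch (argument values are illustrative):
        #     model = Network(dim=256, batch_norm=False, dropout=0.3,
        #                     rec_dropout=0.3, header=header, task='ihm')
        #     model.compile(optimizer='adam', loss='binary_crossentropy')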