import tensorflow as tf
from keras import layers, Model
from tensorflow.keras.layers import Normalization # type: ignore
from keras.utils import register_keras_serializable # type: ignore

# Register the Normalization layer globally with Keras' serialization registry,
# presumably so the normalizer sub-layers that Transformer.get_config serializes
# can be resolved by name when the model is reloaded.
# NOTE(review): Normalization is a built-in Keras layer and is likely already
# deserializable without this call; confirm whether the registration is needed.
register_keras_serializable()(Normalization)

# Transformer model
@register_keras_serializable()
class Transformer(Model):
    """Transformer encoder followed by a BiLSTM head that classifies a sequence.

    ``call`` expects a dict with three entries:

    * ``"lapDependent"``: per-timestep features of rank 3,
      ``(batch, seq, features)``. ``seq`` must not exceed ``maxSeqLength``
      because it indexes the learned positional-embedding table.
    * ``"lapMask"``: rank-2 ``(batch, seq)`` validity mask. Non-zero / True
      marks a real timestep. The prediction is read at index
      ``count(valid) - 1``, so valid timesteps are assumed to be contiguous
      from the start of the sequence (right padding).
    * ``"nonLapDependent"``: rank-2 ``(batch, features)`` per-sequence
      features. These are projected and broadcast to every timestep.

    Args:
        lapDependentNormalizer: Layer applied to ``inputs["lapDependent"]``
            before projection.
        nonLapDependentNormalizer: Layer applied to
            ``inputs["nonLapDependent"]`` before projection.
        maxSeqLength: Size of the positional-embedding table, i.e. the
            longest supported sequence.
        lapDim: Stored only so it round-trips through ``get_config``.
        nonLapDependentDim: Stored only so it round-trips through
            ``get_config``.
        numLayers: Number of attention + feed-forward encoder blocks.
        numHeads: Attention heads per block. Each head uses
            ``d_model // numHeads`` key dimensions.
        d_model: Width of the encoder representation.
        ffnDim: Hidden width of each feed-forward sub-block.
        dropout: Dropout rate. It is applied to the attention probabilities
            and to the attention / feed-forward residual branches, and only
            while training.
        numClasses: Number of softmax output classes.
        name: Optional model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, lapDependentNormalizer, nonLapDependentNormalizer,
                 maxSeqLength, lapDim, nonLapDependentDim, numLayers, numHeads,
                 d_model, ffnDim, dropout, numClasses, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.lapDependentNormalizer = lapDependentNormalizer
        self.nonLapDependentNormalizer = nonLapDependentNormalizer
        self.maxSeqLength = maxSeqLength
        self.lapDim = lapDim
        self.nonLapDependentDim = nonLapDependentDim
        self.numLayers = numLayers
        self.numHeads = numHeads
        self.d_model = d_model
        self.ffnDim = ffnDim
        self.dropout = dropout
        self.numClasses = numClasses

        # Input projection and learned positional embedding.
        self.inputDense = layers.Dense(d_model)
        self.positionalEncoding = layers.Embedding(input_dim=maxSeqLength, output_dim=d_model)

        # Encoder stack of post-norm blocks:
        # x = LN(x + Drop(MHA(x))); x = LN(x + Drop(FFN(x))).
        self.encoders = [
            layers.MultiHeadAttention(
                num_heads=numHeads, key_dim=d_model // numHeads, dropout=dropout
            )
            for _ in range(numLayers)
        ]
        self.ffns = [
            tf.keras.Sequential([
                layers.Dense(ffnDim, activation="relu"),
                layers.Dense(d_model),
            ]) for _ in range(numLayers)
        ]
        self.layerNorm1 = [layers.LayerNormalization() for _ in range(numLayers)]
        self.layerNorm2 = [layers.LayerNormalization() for _ in range(numLayers)]
        # Residual-branch dropout. These layers have no weights, so existing
        # checkpoints still load, and they are inactive at inference.
        self.attentionDropout = [layers.Dropout(dropout) for _ in range(numLayers)]
        self.ffnDropout = [layers.Dropout(dropout) for _ in range(numLayers)]

        # Projection of the per-sequence (non-lap) features, fused with the
        # encoder output.
        self.nonLapDependentProj = layers.Dense(d_model)
        self.finalDense = layers.Dense(d_model, activation="relu")

        # BiLSTM head and output classifier.
        self.biLSTM = layers.Bidirectional(layers.LSTM(64, return_sequences=True))
        self.timeDense = layers.TimeDistributed(layers.Dense(64, activation="relu"))
        self.outDense = layers.Dense(numClasses, activation="softmax")

    def call(self, inputs, training=False, returnIntermediate=False):
        """Run the model on a batch.

        Args:
            inputs: Dict with ``"lapDependent"``, ``"lapMask"`` and
                ``"nonLapDependent"`` entries (see the class docstring).
            training: Whether dropout is active.
            returnIntermediate: If True, also return the per-layer attention
                scores.

        Returns:
            Softmax class probabilities of shape ``(batch, numClasses)``.
            When ``returnIntermediate`` is True, returns the tuple
            ``(output, {"attentionWeights": [scores_layer_0, ...]})``.
        """
        lapDependent = self.lapDependentNormalizer(inputs["lapDependent"])
        nonLapDependent = self.nonLapDependentNormalizer(inputs["nonLapDependent"])
        # One boolean view of the mask, used everywhere, so the attention
        # mask, the LSTM mask and the length count all agree.
        boolMask = tf.cast(inputs["lapMask"], tf.bool)

        # Project to d_model and add positional embeddings for the actual
        # sequence length rather than assuming seq == maxSeqLength.
        x = self.inputDense(lapDependent)
        seqLen = tf.shape(x)[1]
        x = x + self.positionalEncoding(tf.range(seqLen)[None, :])

        # Key-padding mask broadcast over every query position: (batch, 1, seq).
        attentionMask = boolMask[:, None, :]

        attentionWeights = []
        for attention, ffn, norm1, norm2, dropAttention, dropFfn in zip(
            self.encoders, self.ffns, self.layerNorm1, self.layerNorm2,
            self.attentionDropout, self.ffnDropout,
        ):
            attentionOut, scores = attention(
                x, x,
                attention_mask=attentionMask,
                return_attention_scores=True,
                training=training,
            )
            attentionWeights.append(scores)
            x = norm1(x + dropAttention(attentionOut, training=training))

            ffnOut = ffn(x, training=training)
            x = norm2(x + dropFfn(ffnOut, training=training))

        # Broadcast the projected per-sequence features to every timestep and fuse.
        context = self.nonLapDependentProj(nonLapDependent)[:, None, :]
        context = tf.tile(context, [1, seqLen, 1])
        x = self.finalDense(tf.concat([x, context], axis=-1))

        # BiLSTM skips masked timesteps; TimeDistributed refines each step.
        x = self.biLSTM(x, mask=boolMask, training=training)
        x = self.timeDense(x)

        # Gather the last valid timestep per sample. Clamp to 0 so a sample
        # with an all-zero mask does not produce index -1, which makes
        # gather_nd fail on CPU.
        lengths = tf.reduce_sum(tf.cast(boolMask, tf.int32), axis=1)
        lastIndex = tf.maximum(lengths - 1, 0)
        batchIndex = tf.range(tf.shape(x)[0])
        y = tf.gather_nd(x, tf.stack([batchIndex, lastIndex], axis=1))

        output = self.outDense(y)

        if returnIntermediate:
            return output, {"attentionWeights": attentionWeights}
        return output

    def get_config(self):
        """Return constructor arguments, with the normalizers serialized."""
        config = super().get_config()
        config.update({
            "lapDependentNormalizer": tf.keras.layers.serialize(self.lapDependentNormalizer),
            "nonLapDependentNormalizer": tf.keras.layers.serialize(self.nonLapDependentNormalizer),
            "maxSeqLength": self.maxSeqLength,
            "lapDim": self.lapDim,
            "nonLapDependentDim": self.nonLapDependentDim,
            "numLayers": self.numLayers,
            "numHeads": self.numHeads,
            "d_model": self.d_model,
            "ffnDim": self.ffnDim,
            "dropout": self.dropout,
            "numClasses": self.numClasses,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from ``get_config`` output.

        The input dict is copied, so the caller's config is not mutated.
        """
        config = dict(config)
        config["lapDependentNormalizer"] = tf.keras.layers.deserialize(config["lapDependentNormalizer"])
        config["nonLapDependentNormalizer"] = tf.keras.layers.deserialize(config["nonLapDependentNormalizer"])
        return cls(**config)