# TensorFlow core library
import tensorflow as tf

# Keras serialization utility
from keras.utils import register_keras_serializable  # type: ignore

# NOTE(review): the decorator comes from `keras.utils` while the base class is
# `tf.keras.losses.Loss`; if these resolve to different Keras installs (e.g.
# legacy tf_keras), the registration may not be visible to tf.keras loaders.
@register_keras_serializable()
class WeightedSequenceLoss(tf.keras.losses.Loss):
    """Sparse categorical cross-entropy with a per-class weight.

    Each element's cross-entropy is multiplied by ``weights[label]``, where
    ``label`` is the integer class id in ``y_true`` at that position.
    ``call`` returns these per-element values; the base ``Loss.__call__``
    then applies the configured ``reduction`` (by default the mean over all
    elements) together with any ``sample_weight`` / mask supplied by Keras.

    Args:
        weights: One weight per class, indexed by class id. Its length must
            match the last dimension of ``y_pred``.
        name: Name of the loss instance.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, weights, name="WeightedSequenceLoss", **kwargs):
        super().__init__(name=name, **kwargs)
        # Plain Python floats keep get_config() JSON-serializable.
        self._weights_array = [float(w) for w in weights]

    def call(self, y_true, y_pred):
        """Return the class-weighted cross-entropy for every element.

        Args:
            y_true: Integer class ids, shaped like ``y_pred`` without its last
                axis, or with a trailing axis of size 1.
            y_pred: Class probabilities (``from_logits=False``) with classes
                on the last axis.

        Returns:
            Per-element weighted losses, shaped like ``y_pred`` without its
            last axis. Reduction is left to the base class.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.convert_to_tensor(y_true)

        # Sparse labels may arrive with a trailing axis of size 1 (Keras
        # expands rank-1 labels to (batch, 1)). Left in place, one_hot would
        # add an extra axis and the final product would broadcast into a
        # cross product (or fail) instead of pairing each loss with its weight.
        if (
            y_true.shape.rank is not None
            and y_true.shape.rank == y_pred.shape.rank
        ):
            y_true = tf.squeeze(y_true, axis=-1)
        y_true = tf.cast(y_true, tf.int32)

        weights = tf.constant(self._weights_array, dtype=tf.float32)
        y_true_one_hot = tf.one_hot(y_true, depth=tf.shape(y_pred)[-1])
        # Picks weights[label] for each element (0 for out-of-range labels,
        # since their one-hot row is all zeros).
        sample_weights = tf.reduce_sum(weights * y_true_one_hot, axis=-1)

        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
        # Cast so the product also works when the loss is e.g. float16.
        return loss * tf.cast(sample_weights, loss.dtype)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this loss."""
        config = super().get_config()
        # Copy so callers mutating the config cannot alter this instance.
        config.update({"weights": list(self._weights_array), "name": self.name})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the loss from a dict produced by ``get_config``."""
        return cls(**config)