import * as tf from '@tensorflow/tfjs';
import {DataPoint} from "../../components/interfaces/interfaces";

/**
 * Splits `array` into consecutive slices of at most `chunkSize` elements.
 *
 * The final chunk holds the remainder and may be shorter than `chunkSize`.
 * An empty input yields an empty list of chunks.
 *
 * @param array - Values to partition; it is not modified.
 * @param chunkSize - Maximum number of elements per chunk; must be greater than 0.
 * @returns The chunks, in the original element order.
 * @throws RangeError if `chunkSize` is not a positive number.
 */
const chunkArray = (array: number[], chunkSize: number): number[][] => {
  // A zero or negative step never advances the loop index, so reject it up
  // front instead of looping forever. The negated comparison also catches NaN.
  if (!(chunkSize > 0)) {
    throw new RangeError(`chunkSize must be a positive number, got ${chunkSize}`);
  }

  const chunks: number[][] = [];
  for (let i = 0; i < array.length; i += chunkSize) {
    chunks.push(array.slice(i, i + chunkSize));
  }
  return chunks;
}
/**
 * Converts raw rows into typed data points.
 *
 * Each row's `timestamp` is passed to the `Date` constructor, and `oat` and
 * `consumption` are stringified and parsed with `parseFloat`, so values that
 * cannot be parsed become `NaN`. The incoming rows are logged to the console
 * before conversion.
 *
 * @param data - Rows to convert; the input objects are not modified.
 * @returns A new array with one converted object per input row.
 */
export const processData = (data: DataPoint[]): DataPoint[] => {
  console.log(data);

  // Stringify first so that values already stored as numbers parse the same way.
  const toNumber = (value: unknown): number => parseFloat(String(value));

  return data.map(({timestamp, oat, consumption}) => ({
    timestamp: new Date(timestamp),
    oat: toNumber(oat),
    consumption: toNumber(consumption),
  }));
};
/**
 * Min-max scales `data` so that its smallest value maps to 0 and its largest to 1.
 *
 * When every value is identical (zero range) all scaled values are 0 rather
 * than the `NaN` that dividing by a zero range would produce.
 *
 * @param data - Values to scale; the array is not modified.
 * @returns The scaled values together with the `min` and `max` that were used,
 *          so the scaling can be reversed later. For an empty input `scaled`
 *          is empty, `min` is `Infinity` and `max` is `-Infinity`.
 */
const minMaxScale = (data: number[]): { scaled: number[], min: number, max: number } => {
  // Single pass instead of Math.min(...data) / Math.max(...data): spreading a
  // large array into call arguments is what exceeds the call stack size.
  let min = Infinity;
  let max = -Infinity;
  for (const value of data) {
    min = Math.min(min, value);
    max = Math.max(max, value);
  }
  const range = max - min;

  // Constant input has no spread; avoid 0 / 0.
  const scaled = range === 0
    ? data.map(() => 0)
    : data.map(value => (value - min) / range);

  return {scaled, min, max};
}

/**
 * Reverses min-max scaling, mapping each scaled value `v` back to
 * `v * (max - min) + min`.
 *
 * @param scaledData - Scaled values; the array is not modified.
 * @param min - Minimum of the original, unscaled data.
 * @param max - Maximum of the original, unscaled data.
 * @returns A new array with the values in the original scale.
 */
const minMaxInverseScale = (scaledData: number[], min: number, max: number): number[] => {
  const range = max - min;

  // Array.prototype.map iterates without recursion, so no chunking is needed
  // regardless of the input size.
  return scaledData.map(value => value * range + min);
};

/**
 * Weight-free layer that applies, element-wise, the activation
 * `0.1 * x^2 + 0.2 * (sigmoid(x) - 0.5)` to its input.
 */
class CustomLayer extends tf.layers.Layer {
  constructor() {
    super({});
  }

  /**
   * Applies the activation to the input tensor.
   *
   * @param inputs - A single tensor, or an array of tensors of which only the
   *                 first is used.
   * @param kwargs - Unused.
   * @returns A tensor holding the activated values.
   */
  call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[] {
    // tidy releases the intermediate tensors created below and keeps only the result.
    return tf.tidy(() => {
      // The signature allows a tensor list, so unwrap it instead of casting.
      const input = Array.isArray(inputs) ? inputs[0] : inputs;
      const quadraticPart = tf.mul(input, input).mul(0.1);
      const sigmoidPart = tf.sigmoid(input).sub(0.5).mul(0.2);
      return tf.add(quadraticPart, sigmoidPart);
    });
  }

  /** Name under which the class is registered with the tf serialization registry. */
  static get className() {
    return 'CustomLayer';
  }
}

// Register the class so a saved model containing this layer can be deserialized.
tf.serialization.registerClass(CustomLayer);

/**
 * Trains a small neural network that maps temperature (`oat`) to `consumption`
 * on the given rows and returns the model's prediction for each of those rows.
 *
 * Both columns are min-max normalized before training and the predictions are
 * mapped back to the original consumption scale. A column whose values are all
 * identical is normalized with a range of 1 so that no division by zero occurs.
 * All tensors, the model and the optimizer created here are released before
 * the function returns.
 *
 * @param data - Rows to train on; `oat` and `consumption` must be finite numbers.
 * @returns One `{ oat, consumption }` object per input row, where `consumption`
 *          is the predicted value; an empty array for empty input; `undefined`
 *          (after logging an error) when any row holds a non-finite value.
 */
export const trainAndPredict = async (
  data: DataPoint[],
): Promise<{ oat: number; consumption: number }[] | undefined> => {
  // Nothing to train on.
  if (data.length === 0) {
    return [];
  }

  // Reject NaN as well as +/-Infinity: either would turn the normalized data into NaN.
  if (data.some(row => !Number.isFinite(row.oat) || !Number.isFinite(row.consumption))) {
    const errorMessage = "Data contains NaN or invalid values.";
    console.error(errorMessage);
    return undefined;
  }

  // A zero range (all values identical) would make the normalization divide by zero.
  const nonZeroRange = (range: tf.Tensor): tf.Tensor =>
    tf.where(range.equal(0), tf.onesLike(range), range);

  // Build the normalized training tensors; tidy frees every intermediate tensor
  // and keeps only the four that are returned.
  const { xs, ys, energyMin, energyRange } = tf.tidy(() => {
    const tempTensor = tf.tensor2d(data.map(row => row.oat), [data.length, 1]);
    const energyTensor = tf.tensor2d(data.map(row => row.consumption), [data.length, 1]);

    const tempMin = tempTensor.min();
    const minEnergy = energyTensor.min();
    const tempRange = nonZeroRange(tempTensor.max().sub(tempMin));
    const rangeEnergy = nonZeroRange(energyTensor.max().sub(minEnergy));

    return {
      xs: tempTensor.sub(tempMin).div(tempRange),
      ys: energyTensor.sub(minEnergy).div(rangeEnergy),
      energyMin: minEnergy,
      energyRange: rangeEnergy,
    };
  });

  const model = tf.sequential();
  const optimizer = tf.train.adam(0.001);

  try {
    // Define the model
    model.add(tf.layers.dense({ units: 64, activation: 'relu', inputShape: [1] }));
    model.add(new CustomLayer());
    model.add(tf.layers.dense({ units: 1, activation: 'relu' })); // Ensure non-negative outputs

    // Mean squared error plus a penalty on negative predictions. The ReLU on
    // the output layer already keeps predictions non-negative, so the penalty
    // only acts as a safeguard.
    const customLoss = (yTrue: tf.Tensor, yPred: tf.Tensor): tf.Tensor => {
      const negativePenalty = tf.sum(tf.relu(tf.neg(yPred)));
      return tf.losses.meanSquaredError(yTrue, yPred).add(negativePenalty);
    };

    model.compile({ optimizer, loss: customLoss });

    await model.fit(xs, ys, {
      epochs: 5,
      verbose: 0,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          const loss = logs?.loss;
          if (loss === undefined || Number.isNaN(loss)) {
            const errorMessage = "Loss is NaN. Stopping training.";
            console.error(errorMessage);

            // Further epochs cannot recover from a NaN loss.
            model.stopTraining = true;
          } else {
            console.log(`Epoch ${epoch + 1}: Loss = ${loss}`);
          }
        },
      },
    });
    console.log("Training complete!");

    // Generate predictions and map them back to the original consumption scale.
    const predictions = tf.tidy(() =>
      (model.predict(xs) as tf.Tensor).mul(energyRange).add(energyMin));

    try {
      // The prediction tensor has shape [rows, 1], so the flat buffer holds
      // exactly one value per input row, in row order.
      const predictedValues = await predictions.data();

      return data.map((row, index) => ({
        oat: row.oat,
        consumption: predictedValues[index],
      }));
    } finally {
      predictions.dispose();
    }
  } finally {
    // Release everything created for this run, even when training fails.
    tf.dispose([xs, ys, energyMin, energyRange]);
    model.dispose();
    optimizer.dispose();
  }
}


