# Importing the necessary modules
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import tkinter as tk
from tkinter import font,messagebox, Entry, Button
from scipy.stats import multivariate_normal
class App:
    """Interactive 2-D Kalman filter demo.

    A Tk control panel (``self.root``) lets the user set an initial Gaussian
    belief and the movement/sensor noise covariances. The user then applies
    movements (prediction steps) and sensor readings (measurement updates).
    The current belief is drawn as a filled-contour density plot with
    matplotlib.

    State:
        mean: belief mean, ``np.array([x, y])``.
        cov: belief covariance, a 2x2 ``np.ndarray``.
        movement_noise_cov: 2x2 covariance added in each prediction step.
        sensor_noise_cov: 2x2 covariance of the position sensor.
    """

    def plot(self):
        """Redraw the current belief.

        Deliberately does not call ``self.root.mainloop()``. Within this class
        ``plot`` is only reached from Tk button callbacks, so the main loop is
        already running. Re-entering it here would nest another event loop
        (and stack frame) on every button press.
        """
        self.plot_kalman()

    @staticmethod
    def _pdf_grid(mean, cov, limit=20, num=200):
        """Evaluate the N(mean, cov) density on a square grid.

        Args:
            mean: 2-element mean vector.
            cov: 2x2 covariance matrix.
            limit: the grid spans ``[-limit, limit]`` on both axes.
            num: number of grid points per axis.

        Returns:
            ``(X, Y, pdf)``: the ``np.meshgrid`` coordinate arrays and the
            density at each grid point, all of shape ``(num, num)``.

        Raises:
            Whatever ``scipy.stats.multivariate_normal`` raises for an invalid
            covariance (e.g. one that is not positive definite).
        """
        axis = np.linspace(-limit, limit, num=num)
        X, Y = np.meshgrid(axis, axis)
        distr = multivariate_normal(mean=mean, cov=cov)
        # One vectorised pdf call over the (num, num, 2) stack of points
        # replaces num*num scalar calls. The reshape undoes scipy's squeezing
        # of size-1 output axes.
        pdf = np.reshape(distr.pdf(np.dstack((X, Y))), X.shape)
        return X, Y, pdf

    def plot_kalman(self):
        """Draw the density of the current belief N(self.mean, self.cov).

        Clears the current pyplot figure and draws 50 filled contour levels
        on a square-root (PowerNorm, gamma=0.5) colour scale with a colour
        bar. It marks the mean with a small green disc, then refreshes the
        window without blocking.
        """
        X, Y, pdf = self._pdf_grid(self.mean, self.cov)

        norm = colors.PowerNorm(gamma=0.5, vmin=0, vmax=pdf.max(), clip=False)
        plt.clf()
        plt.contourf(X, Y, pdf, cmap="plasma", norm=norm, levels=50)
        # Mark the mean.
        plt.gca().add_patch(plt.Circle(self.mean, 0.25, color="green", fill=True))
        cbar = plt.colorbar()
        cbar.set_label("PDF")
        plt.axis("equal")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title("Kalman Filter")
        # A blocking plt.show() would start a second GUI event loop inside the
        # calling button callback. plt.pause() draws, shows the window if it
        # is not yet visible, runs the GUI loop briefly, and returns.
        plt.pause(0.001)

    def __init__(self):
        """Create the figure and the control panel, then run the Tk main loop.

        Blocks until the Tk main loop exits.
        """
        plt.rcParams['figure.figsize'] = 7.5, 6
        self.fig = plt.figure()
        self._build_controls()

    def _build_controls(self):
        """Build the widgets, load the default noise and belief, and run Tk."""
        self._build_controls_for_manual_kalman()
        self.on_set_noise()
        # Seed the belief from the entries' default values without drawing.
        # This lets the movement/sensor buttons work before "Set to this
        # distribution" is pressed, instead of failing on a missing
        # self.mean / self.cov.
        self.mean, self.cov = self._read_initial_distribution()
        self.root.mainloop()

    def _build_section(self, title):
        """Add a titled LabelFrame, stretched horizontally, to the panel."""
        section = tk.LabelFrame(self.root, text=title)
        section.pack(fill="x", padx=10, pady=6)
        return section

    @staticmethod
    def _build_labeled_entry(parent, text, default="0"):
        """Add a left-aligned ``label [entry]`` row to parent; return the Entry."""
        row = tk.Frame(parent)
        row.pack(anchor="w", padx=10)
        tk.Label(row, text=text).pack(side=tk.LEFT)
        entry = Entry(row, width=8)
        entry.insert(0, default)
        entry.pack(side=tk.LEFT)
        return entry

    def _build_cov_grid(self, parent, title, validator):
        """Add a titled 2x2 grid of Entries for a covariance matrix.

        The grid starts as the identity. The top-right (xy) Entry is disabled.
        ``validator`` runs whenever the bottom-left (yx) Entry is edited; the
        validators in this class copy the yx text into xy so the matrix
        stays symmetric.

        Returns:
            The ``(xx, xy, yx, yy)`` Entry widgets.
        """
        frame = tk.Frame(parent)
        frame.pack(anchor="w")
        tk.Label(frame, text=title).grid(row=1, column=0, columnspan=2, pady=2, sticky="w")
        entries = []
        for row, column, default in ((2, 0, "1"), (2, 1, "0"), (3, 0, "0"), (3, 1, "1")):
            entry = Entry(frame, width=8)
            entry.grid(row=row, column=column, pady=2, padx=2)
            entry.insert(0, default)
            entries.append(entry)
        xx, xy, yx, yy = entries
        xy.configure(state="disabled")
        # Attached only after the default is inserted. 'key' validation also
        # fires on programmatic insert, and the validator's target attribute
        # is not assigned until this method returns.
        yx.configure(validate="key", validatecommand=(self.root.register(validator), "%P"))
        return xx, xy, yx, yy

    def _build_controls_for_manual_kalman(self):
        """Create the Tk control panel ``self.root`` and all of its widgets."""
        self.root = tk.Tk()
        font.nametofont("TkDefaultFont").configure(size=16)
        font.nametofont("TkTextFont").configure(size=16)
        self.root.title("Kalman Filter")
        self.root.geometry("400x1000")

        f_initial = self._build_section("Initial Distribution")
        self.x_start_entry = self._build_labeled_entry(f_initial, "Starting x =")
        self.y_start_entry = self._build_labeled_entry(f_initial, "Starting y =")
        (self.xx_cov_entry, self.xy_cov_entry,
         self.yx_cov_entry, self.yy_cov_entry) = self._build_cov_grid(
            f_initial, "Starting covariance", self._validate_yx_cov_entry)
        Button(f_initial, text="Set to this distribution",
               command=self.on_set_position).pack(pady=8)

        f_noise = self._build_section("Noise")
        (self.movement_noise_xx_cov_entry, self.movement_noise_xy_cov_entry,
         self.movement_noise_yx_cov_entry, self.movement_noise_yy_cov_entry) = self._build_cov_grid(
            f_noise, "Movement covariance", self._validate_movement_noise_yx_cov_entry)
        (self.sensor_noise_xx_cov_entry, self.sensor_noise_xy_cov_entry,
         self.sensor_noise_yx_cov_entry, self.sensor_noise_yy_cov_entry) = self._build_cov_grid(
            f_noise, "Sensor covariance", self._validate_sensor_noise_yx_cov_entry)
        Button(f_noise, text="Set these noise parameters",
               command=self.on_set_noise).pack(pady=8)

        f_movement = self._build_section("Add a movement")
        self.x_movement_entry = self._build_labeled_entry(f_movement, "Change in x =")
        self.y_movement_entry = self._build_labeled_entry(f_movement, "Change in y =")
        Button(f_movement, text="Apply this movement",
               command=self.on_move_manual_kalman).pack(pady=8)

        f_sensor = self._build_section("Add a sensor measurement")
        self.x_sensor_entry = self._build_labeled_entry(f_sensor, "Sensor x =")
        self.y_sensor_entry = self._build_labeled_entry(f_sensor, "Sensor y =")
        Button(f_sensor, text="Apply this sensor reading",
               command=self.on_sense_manual_kalman).pack(pady=8)

    @staticmethod
    def _mirror_entry(target, text):
        """Overwrite the disabled Entry ``target`` with ``text``.

        Helper for the Tk validatecommands below. It always returns True, so
        the edit to the source Entry is accepted.
        """
        target.configure(state="normal")  # a disabled Entry ignores insert/delete
        target.delete(0, tk.END)
        target.insert(0, text)
        target.configure(state="disabled")
        return True

    def _validate_sensor_noise_yx_cov_entry(self, proposed):
        """Mirror the proposed sensor-noise yx text (Tk ``%P``) into xy."""
        return self._mirror_entry(self.sensor_noise_xy_cov_entry, proposed)

    def _validate_movement_noise_yx_cov_entry(self, proposed):
        """Mirror the proposed movement-noise yx text (Tk ``%P``) into xy."""
        return self._mirror_entry(self.movement_noise_xy_cov_entry, proposed)

    def _validate_yx_cov_entry(self, proposed):
        """Mirror the proposed starting-covariance yx text (Tk ``%P``) into xy."""
        return self._mirror_entry(self.xy_cov_entry, proposed)

    def _parse_required_float(self, entry, name):
        """Return the Entry's stripped text as a finite float.

        Raises:
            ValueError: if the text is empty, not a number, or not finite.
                ``name`` identifies the field in the message.
        """
        s = entry.get().strip()
        if s == "":
            raise ValueError(f"Missing value for {name}")
        try:
            value = float(s)
        except ValueError:
            raise ValueError(f"Invalid number for {name}: {s!r}") from None
        # float() accepts "nan"/"inf", which would otherwise surface later as
        # an obscure scipy / linear-algebra error far from the bad entry.
        if not np.isfinite(value):
            raise ValueError(f"Value for {name} must be a finite number")
        return value

    def _read_cov(self, entries, name):
        """Parse four Entries ``(xx, xy, yx, yy)`` into a 2x2 ``np.ndarray``.

        Error messages name each cell as ``"<name> xx"``, ``"<name> xy"``, etc.
        """
        xx, xy, yx, yy = (
            self._parse_required_float(entry, f"{name} {cell}")
            for entry, cell in zip(entries, ("xx", "xy", "yx", "yy"))
        )
        return np.array([[xx, xy], [yx, yy]])

    def _read_initial_distribution(self):
        """Parse the "Initial Distribution" entries into ``(mean, cov)``."""
        mean = np.array([
            self._parse_required_float(self.x_start_entry, "starting x"),
            self._parse_required_float(self.y_start_entry, "starting y"),
        ])
        cov = self._read_cov(
            (self.xx_cov_entry, self.xy_cov_entry, self.yx_cov_entry, self.yy_cov_entry),
            "starting covariance")
        return mean, cov

    @staticmethod
    def _predict(mean, cov, movement, movement_noise_cov):
        """Kalman prediction for a pure translation: return (mean', cov').

        The mean is shifted by ``movement`` and the covariance is inflated by
        ``movement_noise_cov``.
        """
        return mean + movement, cov + movement_noise_cov

    @staticmethod
    def _update(mean, cov, measurement, sensor_noise_cov):
        """Kalman measurement update for a sensor that observes position directly.

        Returns:
            ``(mean', cov')`` after fusing ``measurement``, whose noise has
            covariance ``sensor_noise_cov``.
        """
        innovation = measurement - mean
        innovation_cov = cov + sensor_noise_cov
        gain = cov @ np.linalg.inv(innovation_cov)
        new_mean = mean + gain @ innovation
        new_cov = (np.eye(len(mean)) - gain) @ cov
        return new_mean, new_cov

    def _apply_state(self, mean, cov):
        """Commit a new belief ``(mean, cov)`` and redraw it.

        If drawing raises (e.g. scipy rejects ``cov``), the previous belief is
        restored before the error propagates. One bad input therefore cannot
        leave the filter in a state that every later action fails on.
        """
        previous = (self.mean, self.cov)
        self.mean, self.cov = mean, cov
        try:
            self.plot()
        except Exception:
            self.mean, self.cov = previous
            raise

    def on_sense_manual_kalman(self):
        """Fuse the sensor reading from the entries into the belief and redraw."""
        try:
            reading = np.array([
                self._parse_required_float(self.x_sensor_entry, "sensor x"),
                self._parse_required_float(self.y_sensor_entry, "sensor y"),
            ])
            self._apply_state(*self._update(self.mean, self.cov, reading,
                                            self.sensor_noise_cov))
        except Exception as e:
            messagebox.showerror("Parameter reset error", str(e))

    def on_move_manual_kalman(self):
        """Apply the movement from the entries as a prediction step and redraw."""
        try:
            movement = np.array([
                self._parse_required_float(self.x_movement_entry, "movement x"),
                self._parse_required_float(self.y_movement_entry, "movement y"),
            ])
            self._apply_state(*self._predict(self.mean, self.cov, movement,
                                             self.movement_noise_cov))
        except Exception as e:
            messagebox.showerror("Parameter reset error", str(e))

    def on_set_position(self):
        """Replace the belief with the "Initial Distribution" entries and redraw."""
        try:
            self._apply_state(*self._read_initial_distribution())
        except Exception as e:
            messagebox.showerror("Parameter reset error", str(e))

    def on_set_noise(self):
        """Load both noise covariances from the "Noise" entries.

        Nothing is committed unless both matrices parse.
        """
        try:
            movement_noise_cov = self._read_cov(
                (self.movement_noise_xx_cov_entry, self.movement_noise_xy_cov_entry,
                 self.movement_noise_yx_cov_entry, self.movement_noise_yy_cov_entry),
                "movement noise covariance")
            sensor_noise_cov = self._read_cov(
                (self.sensor_noise_xx_cov_entry, self.sensor_noise_xy_cov_entry,
                 self.sensor_noise_yx_cov_entry, self.sensor_noise_yy_cov_entry),
                "sensor noise covariance")
        except Exception as e:
            messagebox.showerror("Parameter reset error", str(e))
            return
        self.movement_noise_cov = movement_noise_cov
        self.sensor_noise_cov = sensor_noise_cov

if __name__ == "__main__":
    # Constructing App builds the control panel and enters the Tk main loop,
    # so this call does not return until that loop exits.
    App()