# Importing the necessary modules
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import tkinter as tk
import random as random
from tkinter import font,messagebox, Entry, Button
from scipy.stats import multivariate_normal
class App:
    """Tkinter control panel plus a matplotlib view of a 2-D Kalman filter demo.

    Four preset models (buttons A-D) set an initial true position, an initial
    filter belief (mean / covariance) and the movement / sensor noise
    covariances.  Each "Next step" moves the actor along a scripted path with
    movement noise, takes a noisy position measurement, runs one Kalman
    predict/update cycle (identity motion and measurement model) and redraws.
    """

    def __init__(self):
        plt.rcParams['figure.figsize'] = 7.5, 6
        # 0 means "no model selected yet"; 1-4 select a preset
        # (see the on_set_initial_position_for_type_* methods).
        self.type = 0
        self.step = 0
        self.movements_made = False
        self.fig = plt.figure()
        self.rng = np.random.default_rng()
        # Builds the control window and blocks in its event loop.
        self._build_controls()

    def plot(self):
        """Redraw the figure and return control to the running Tk event loop.

        Every caller is a button callback, i.e. it already runs inside the
        ``root.mainloop()`` started by ``_build_controls``.  Entering another
        event loop here would nest one more loop per click and the callback
        would never return.
        """
        self.plot_kalman_with_actor()

    def plot_kalman_with_actor(self):
        """Draw the filter's Gaussian belief plus the tracked points.

        Does nothing until a model has been selected (``self.type == 0``).
        """
        if self.type == 0:
            return

        # Evaluate the belief density on a 200x200 grid over [-25, 25]^2 in a
        # single vectorised call instead of 40,000 scalar pdf() calls.
        axis_values = np.linspace(-25, 25, num=200)
        X, Y = np.meshgrid(axis_values, axis_values)
        belief = multivariate_normal(mean=self.mean, cov=self.cov)
        pdf = belief.pdf(np.dstack((X, Y)))
        maximum = pdf.max()

        fig = plt.gcf()
        fig.clear()
        # gamma < 1 lifts low densities in the colour map so the tails show.
        norm = colors.PowerNorm(gamma=0.5, vmin=0, vmax=maximum, clip=False)
        filled = plt.contourf(X, Y, pdf, cmap='plasma', norm=norm, levels=50)

        self._draw_marker(self.mean, 'green', "filter mean")
        self._draw_marker(self.actualposition, 'grey', "actual position")
        # The prediction and measurement only exist once a step was taken.
        if self.movements_made:
            self._draw_marker(self.movement_prediction, 'cyan',
                              "movement model prediction")
            self._draw_marker(self.sensor_prediction, 'lightsteelblue',
                              "sensor data")

        cbar = plt.colorbar(filled)
        cbar.set_label("PDF")
        plt.axis("equal")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.title('Filter')
        # Non-blocking: the Tk loop started in _build_controls keeps servicing
        # the figure window, so no further event loop is entered here.
        plt.show(block=False)
        fig.canvas.draw_idle()

    @staticmethod
    def _draw_marker(point, color, label):
        """Draw a small filled circle at ``point`` with ``label`` to its right."""
        plt.gca().add_patch(plt.Circle(point, 0.25, color=color, fill=True))
        plt.text(point[0] + 0.75, point[1], label, color='black',
                 bbox=dict(boxstyle="round", facecolor='white', alpha=0.3))

    def _build_controls(self):
        """Create the control window and run its event loop (blocks)."""
        self._build_controls_for_kalman()
        self.root.mainloop()

    def _build_controls_for_kalman(self):
        """Create ``self.root`` with one button per model plus Reset / Next step."""
        self.root = tk.Tk()
        # Bind the named fonts to self.root explicitly: the pyplot figure
        # created in __init__ may already own a Tk instance that tkinter
        # treats as its default root, and named fonts are per interpreter.
        for font_name in ("TkDefaultFont", "TkTextFont"):
            font.Font(root=self.root, name=font_name, exists=True).configure(size=16)
        self.root.title("Filters")
        self.root.geometry("300x500")

        f_commands = tk.LabelFrame(self.root, text="Commands")
        f_commands.pack(fill="x", padx=10, pady=6)
        buttons = [
            ("Set to model A", self.on_set_initial_position_for_type_3),
            ("Set to model B", self.on_set_initial_position_for_type_1),
            ("Set to model C", self.on_set_initial_position_for_type_4),
            ("Set to model D", self.on_set_initial_position_for_type_2),
            ("Reset", self.on_set_initial_position),
            ("Next step", self.on_move_and_measure),
        ]
        for text, command in buttons:
            Button(f_commands, text=text, command=command).pack(pady=8)

    def _move_towards(self, destx, desty, threshold, clamp_to):
        """Return the move from the actual position towards (destx, desty).

        A move longer than ``threshold`` is rescaled, keeping its direction,
        to length ``clamp_to``.
        """
        movex = destx - self.actualposition[0]
        movey = desty - self.actualposition[1]
        magnitude = np.sqrt(movex**2 + movey**2)
        if magnitude > threshold:
            movex = movex * clamp_to / magnitude
            movey = movey * clamp_to / magnitude
        return np.array([movex, movey])

    def get_intended_move(self):
        """Advance the step counter and return the commanded move for this step.

        The actor steers towards a waypoint on a path chosen by ``self.type``:

        * 1 / 2 -- the 12 "clock" positions on a circle of radius 13, visited
          counter-clockwise / clockwise;
        * 3 -- zig-zag: x sweeps -15..15 and back every 20 steps while y
          sweeps 15..-15 and back every 10 steps;
        * 4 -- clockwise around the square [-15, 15]^2, 10 steps per side.

        Returns:
            np.ndarray of shape (2,): the (dx, dy) move; ``[0, 0]`` for any
            other type.
        """
        self.step = self.step + 1
        if self.type in (1, 2):
            # Type 2 walks the same waypoints in the opposite direction.
            cycle = (self.step if self.type == 1 else -self.step) % 12
            angle = np.pi / 6 * cycle
            # NOTE(review): moves longer than 8 are shortened to 7 here but to
            # 8 for types 3/4 -- kept as-is; confirm whether that is intended.
            return self._move_towards(np.cos(angle) * 13, np.sin(angle) * 13,
                                      threshold=8, clamp_to=7)

        if self.type == 3:
            cycle = self.step % 20
            destx = cycle * 3 - 15 if cycle < 10 else (20 - cycle) * 3 - 15
            cycle = self.step % 10
            desty = -cycle * 6 + 15 if cycle < 5 else (cycle - 10) * 6 + 15
            return self._move_towards(destx, desty, threshold=8, clamp_to=8)

        if self.type == 4:
            cycle = self.step % 40
            if cycle < 10:
                destx, desty = cycle * 3 - 15, 15
            elif cycle < 20:
                destx, desty = 15, (cycle - 10) * -3 + 15
            elif cycle < 30:
                destx, desty = (cycle - 20) * -3 + 15, -15
            else:
                destx, desty = -15, (cycle - 30) * 3 - 15
            return self._move_towards(destx, desty, threshold=8, clamp_to=8)

        return np.array([0, 0])  # unknown type; on_move_and_measure guards type 0

    def on_move_and_measure(self):
        """Button handler: move, measure, run one Kalman cycle, redraw.

        Does nothing until a model has been selected, because the filter state
        read below is only created by the model preset handlers.
        """
        if self.type == 0:
            return
        self.movements_made = True

        # Predict: previous mean plus the commanded move; uncertainty grows.
        intended_move = self.get_intended_move()
        self.movement_prediction = self.mean + intended_move
        self.cov = self.cov + self.movement_noise_cov

        # Simulate the true motion.  Model C (type 4) deliberately departs from
        # the filter's noise model: half the time the noise is a fixed (5, 0).
        movement_noise = self.rng.multivariate_normal([0, 0], self.movement_noise_cov)
        if self.type == 4 and self.rng.random() < 0.5:
            movement_noise = np.array([5, 0])
        self.actualposition = self.actualposition + intended_move + movement_noise

        # Simulate the measurement; model C also reports (0, 0) 20% of the time.
        sensor_noise = self.rng.multivariate_normal([0, 0], self.sensor_noise_cov)
        self.sensor_prediction = self.actualposition + sensor_noise
        if self.type == 4 and self.rng.random() < 0.2:
            self.sensor_prediction = np.array([0, 0])

        # Update with an identity measurement model:
        #   K = P (P + R)^-1,  x = x_pred + K (z - x_pred),  P = (I - K) P
        innovation = self.sensor_prediction - self.movement_prediction
        innovation_cov = self.cov + self.sensor_noise_cov
        gain = self.cov @ np.linalg.inv(innovation_cov)
        self.mean = self.movement_prediction + gain @ innovation
        self.cov = (np.eye(2) - gain) @ self.cov
        self.plot()

    def on_set_initial_position(self):
        """Button handler for "Reset": re-apply the current model's preset.

        Does nothing if no model has been selected yet.
        """
        presets = {
            1: self.on_set_initial_position_for_type_1,
            2: self.on_set_initial_position_for_type_2,
            3: self.on_set_initial_position_for_type_3,
            4: self.on_set_initial_position_for_type_4,
        }
        reset = presets.get(self.type)
        if reset is not None:
            reset()

    def _apply_preset(self, model_type, actualposition, mean, cov,
                      movement_noise_cov, sensor_noise_cov):
        """Install a model's initial state and noise settings, then restart."""
        self.type = model_type
        self.actualposition = np.array(actualposition, dtype=float)
        self.mean = np.array(mean, dtype=float)
        self.cov = np.array(cov, dtype=float)
        self.movement_noise_cov = np.array(movement_noise_cov, dtype=float)
        self.sensor_noise_cov = np.array(sensor_noise_cov, dtype=float)
        self.on_set_initial_position_finish()

    def on_set_initial_position_for_type_1(self):
        """Model B: circle path; motion noise 0.5, sensor noise 15 per axis."""
        self._apply_preset(1, [12, 0], [13, -1], np.eye(2),
                           np.diag([0.5, 0.5]), np.diag([15, 15]))

    def on_set_initial_position_for_type_2(self):
        """Model D: reversed circle; motion noise 15, sensor noise 2 per axis."""
        self._apply_preset(2, [12, 0], [14, -3], np.diag([25 / 12, 25 / 12]),
                           np.diag([15, 15]), np.diag([2, 2]))

    def on_set_initial_position_for_type_3(self):
        """Model A: zig-zag path; anisotropic motion (15, 1) and sensor (2, 20) noise."""
        self._apply_preset(3, [-15, 15], [-14, 14], np.eye(2),
                           np.diag([15, 1]), np.diag([2, 20]))

    def on_set_initial_position_for_type_4(self):
        """Model C: square path; unit noise, plus the outliers in on_move_and_measure."""
        self._apply_preset(4, [-15, 15], [-15, 15], np.diag([25 / 12, 25 / 12]),
                           np.eye(2), np.eye(2))

    def on_set_initial_position_finish(self):
        """Restart the step counter, hide the per-step markers and redraw."""
        self.movements_made = False
        self.step = 0
        self.plot()


def _main():
    """Launch the filter demo; App's constructor runs the Tk event loop."""
    App()


if __name__ == "__main__":
    _main()