import math
import tkinter as tk
from tkinter import messagebox, StringVar, OptionMenu, Entry, Button

import numpy as np

import matplotlib
matplotlib.use("TkAgg")
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

# ==============================
# Fixed environments
# ==============================

CIRCLE_RADIUS: float = 1.0  # radius of the disk environment {x^2 + y^2 <= R^2}
SQUARE_HALF: float = 1.0    # half side length of the square environment [-S, S]^2

# ------------------------------
# Environment membership
# ------------------------------

def in_env_circle(x, y):
    """Return True if (x, y) is in the closed disk of radius CIRCLE_RADIUS.

    A slack of 1e-9 on the squared radius keeps points that sit on the
    boundary up to rounding.
    """
    limit = CIRCLE_RADIUS**2 + 1e-9
    return x*x + y*y <= limit

def in_env_square(x, y):
    """Return True if (x, y) is in the closed square [-SQUARE_HALF, SQUARE_HALF]^2.

    Each coordinate gets a 1e-9 slack so boundary points survive rounding.
    """
    limit = SQUARE_HALF + 1e-9
    return abs(x) <= limit and abs(y) <= limit

# ------------------------------
# Closest-boundary distance
# ------------------------------

def distance_to_boundary_circle(x, y):
    """Return CIRCLE_RADIUS minus the distance of (x, y) from the origin.

    This is the clearance to the disk boundary; it is negative outside.
    """
    radius_here = math.sqrt(x*x + y*y)
    return CIRCLE_RADIUS - radius_here

def distance_to_boundary_square(x, y):
    """Return the distance from (x, y) to the nearest side of the square."""
    return min(SQUARE_HALF - abs(coord) for coord in (x, y))

# ------------------------------
# Heading ray intersection helpers
# ------------------------------

def distance_along_heading_circle(x, y, theta):
    """Return the distance travelled along heading ``theta`` before leaving the disk.

    The robot starts at (x, y) and moves along (cos theta, sin theta); the
    disk has radius ``CIRCLE_RADIUS`` and is centred at the origin.  Callers
    in this file only pass points inside or on the disk.  The result uses the
    infimum exit-time convention:

    * on the boundary, heading outward or tangent -> 0.0;
    * otherwise the parameter t > 0 where the ray leaves the circle.

    Returns:
        The exit distance, or ``None`` when the ray never meets the circle
        (only possible for starting points outside the disk).
    """
    dx = math.cos(theta)
    dy = math.sin(theta)

    eps = 1e-9
    r2 = x*x + y*y
    radial = x*dx + y*dy  # p . d: > 0 outward, < 0 inward, ~0 tangent

    # Infimum exit-time behaviour on the boundary: outward or tangent headings
    # leave immediately.  The -eps slack matters: an exactly tangent heading
    # produces radial ~ +/-1e-16 from rounding, and a tiny negative value used
    # to fall through to a degenerate quadratic that returned None.
    if abs(r2 - CIRCLE_RADIUS**2) < eps and radial >= -eps:
        return 0.0

    # Ray-circle intersection: |p + t d|^2 = R^2  =>  t^2 + B t + C = 0.
    B = 2 * radial
    C = r2 - CIRCLE_RADIUS**2
    disc = B*B - 4*C

    if disc < 0:
        return None

    s = math.sqrt(disc)
    t1 = (-B + s) / 2  # larger root
    t2 = (-B - s) / 2  # smaller root

    if C < eps:
        # Inside the disk, or on the boundary heading inward: the ray leaves
        # through the larger root, which is positive here.  No minimum-t
        # filter, so points just inside the wall heading outward get their
        # small exit distance rather than None.
        return t1 if t1 > 0 else None

    # Strictly outside the disk: first crossing ahead of the robot.
    eps_t = 1e-6
    cand = [t for t in (t1, t2) if t > eps_t]
    return min(cand) if cand else None


def distance_along_heading_square(x, y, theta):
    """Return the distance travelled along heading ``theta`` before leaving the square.

    The square is [-SQUARE_HALF, SQUARE_HALF]^2 and the robot moves from
    (x, y) along (cos theta, sin theta).  Infimum exit-time convention:
    standing on a wall and heading strictly out through it returns 0.0.
    Sliding along a wall (heading parallel to it) counts as staying inside.

    Returns:
        The smallest wall-hit parameter t > 1e-6 whose hit point lies on the
        square's boundary, or ``None`` if there is none.
    """
    half = SQUARE_HALF
    dx = math.cos(theta)
    dy = math.sin(theta)

    eps = 1e-6      # minimum accepted hit parameter (skips the wall we stand on)
    dir_eps = 1e-9  # direction components below this are treated as zero
    pos_eps = 1e-9  # slack for "on a wall" and "hit lies within a wall segment"

    # Infimum exit-time behaviour on the boundary.  dir_eps matters here:
    # cos(pi/2) and sin(pi) are ~1e-16, not 0.  A bare "> 0" test therefore
    # reported e.g. heading straight up along the right wall as an exit.
    on_right = abs(x - half) < pos_eps
    on_left = abs(x + half) < pos_eps
    on_top = abs(y - half) < pos_eps
    on_bottom = abs(y + half) < pos_eps

    if ((on_right and dx > dir_eps) or (on_left and dx < -dir_eps)
            or (on_top and dy > dir_eps) or (on_bottom and dy < -dir_eps)):
        return 0.0

    # Hit points are accepted with a little slack: rounding can put a corner
    # hit at e.g. 1.0000000000000002, which must not be discarded.
    lo, hi = -half - pos_eps, half + pos_eps
    ts = []

    if abs(dx) > dir_eps:
        for wall in (half, -half):
            t = (wall - x) / dx
            if t > eps and lo <= y + t*dy <= hi:
                ts.append(t)

    if abs(dy) > dir_eps:
        for wall in (half, -half):
            t = (wall - y) / dy
            if t > eps and lo <= x + t*dx <= hi:
                ts.append(t)

    return min(ts) if ts else None
# ------------------------------
# Forward / backward distances
# ------------------------------

def distance_along_forward(env_type, x, y, theta):
    """Exit distance along heading ``theta``.

    Uses the disk geometry for env_type "Disk" and the square geometry for
    any other value.
    """
    if env_type == "Disk":
        return distance_along_heading_circle(x, y, theta)
    return distance_along_heading_square(x, y, theta)

def distance_along_back(env_type, x, y, theta):
    """Exit distance along the reversed heading ``theta + pi``."""
    flipped = theta + math.pi
    return distance_along_forward(env_type, x, y, flipped)

# ------------------------------
# Sensor functions
# ------------------------------

def backward_distance(env_type, x, y, theta):
    """Sensor: distance to the boundary directly behind the robot.

    Equivalent to the forward exit distance along ``theta + pi``.
    """
    return distance_along_forward(env_type, x, y, theta + math.pi)

def corridor_width(env_type, x, y, theta):
    """Sensor: forward plus backward exit distance along the heading line.

    Returns ``None`` if either of the two distances is undefined.
    """
    ahead = distance_along_forward(env_type, x, y, theta)
    behind = distance_along_back(env_type, x, y, theta)
    if ahead is None or behind is None:
        return None
    return ahead + behind

def two_ray_aperture(env_type, x, y, theta, alpha):
    """Sensor: shorter exit distance of two rays at ``theta +/- alpha``.

    ``alpha`` is in radians (compute_preimage converts from degrees).  Rays
    with no defined distance are ignored; returns ``None`` if both are
    undefined.
    """
    readings = [
        d
        for d in (distance_along_forward(env_type, x, y, theta + alpha),
                  distance_along_forward(env_type, x, y, theta - alpha))
        if d is not None
    ]
    return min(readings) if readings else None

# ------------------------------
# Analytical singular samples
# ------------------------------

def add_origin_samples_if_needed(env_type, sensor_type, target_h, N_theta, tol):
    """Return exact (0, 0, theta) samples for the disk's closest-boundary sensor.

    On the disk, h = CIRCLE_RADIUS is attained only at the origin, which the
    (x, y) grid may miss.  When ``target_h`` is within ``tol`` of that value,
    one sample per heading in [0, 2*pi) is returned; otherwise the result is
    an empty list.
    """
    needs_origin = (
        env_type == "Disk"
        and sensor_type == "Distance to closest boundary"
        and abs(target_h - CIRCLE_RADIUS) <= tol
    )
    if not needs_origin:
        return []
    headings = np.linspace(0, 2*math.pi, N_theta, endpoint=False)
    return [(0.0, 0.0, heading) for heading in headings]

def add_circle_boundary_samples(env_type, sensor_type, target_h,
                                alpha_deg, N_theta, tol):
    """Return preimage samples lying exactly on the disk boundary circle.

    200 evenly spaced boundary points are paired with ``N_theta`` headings.
    Each (x, y, theta) whose sensor reading is within ``tol`` of ``target_h``
    is kept.  The disk geometry is always used (``env_type`` is not
    consulted).  An unrecognised ``sensor_type`` yields an empty list.
    """
    angles = np.linspace(0, 2*math.pi, 200, endpoint=False)
    headings = np.linspace(0, 2*math.pi, N_theta, endpoint=False)
    aperture = math.radians(alpha_deg)

    readers = {
        "Distance to closest boundary":
            lambda px, py, heading: distance_to_boundary_circle(px, py),
        "Distance along heading":
            lambda px, py, heading: distance_along_heading_circle(px, py, heading),
        "Distance behind robot":
            lambda px, py, heading: backward_distance("Disk", px, py, heading),
        "Corridor width along heading":
            lambda px, py, heading: corridor_width("Disk", px, py, heading),
        "Two-ray aperture":
            lambda px, py, heading: two_ray_aperture("Disk", px, py, heading, aperture),
    }
    read = readers.get(sensor_type)
    if read is None:
        return []

    samples = []
    for angle in angles:
        px = CIRCLE_RADIUS * math.cos(angle)
        py = CIRCLE_RADIUS * math.sin(angle)
        for heading in headings:
            reading = read(px, py, heading)
            if reading is not None and abs(reading - target_h) <= tol:
                samples.append((px, py, heading))
    return samples

# ------------------------------
# Preimage computation
# ------------------------------

def _select_sensor(env_type, sensor_type, alpha):
    """Resolve ``sensor_type`` once into a callable ``(x, y, theta) -> h``.

    ``alpha`` (radians) is only used by the two-ray aperture sensor.  The
    callable returns ``None`` wherever the underlying sensor does.

    Raises:
        ValueError: if ``sensor_type`` is not one of the known sensor names.
    """
    if sensor_type == "Distance to closest boundary":
        closest = (distance_to_boundary_circle if env_type == "Disk"
                   else distance_to_boundary_square)
        return lambda x, y, th: closest(x, y)
    if sensor_type == "Distance along heading":
        return lambda x, y, th: distance_along_forward(env_type, x, y, th)
    if sensor_type == "Distance behind robot":
        return lambda x, y, th: backward_distance(env_type, x, y, th)
    if sensor_type == "Corridor width along heading":
        return lambda x, y, th: corridor_width(env_type, x, y, th)
    if sensor_type == "Two-ray aperture":
        return lambda x, y, th: two_ray_aperture(env_type, x, y, th, alpha)
    raise ValueError("Unknown sensor")


def compute_preimage(env_type, sensor_type, target_h, alpha_deg, N_xy, N_theta):
    """Sample the preimage of ``target_h`` under the selected sensor.

    An ``N_xy`` x ``N_xy`` grid over the environment's bounding box is
    restricted to points inside the environment.  It is combined with
    ``N_theta`` headings in [0, 2*pi).  A configuration (x, y, theta) is kept
    when its sensor reading is within ``tol`` (0.7 grid spacings) of
    ``target_h``.  For the disk, analytic samples at the origin and on the
    boundary circle are appended.

    Args:
        env_type: "Disk" or any other value (treated as the square).
        sensor_type: one of the sensor names listed in ``_select_sensor``.
        target_h: sensor value whose preimage is sampled.
        alpha_deg: two-ray aperture half-angle, in degrees.
        N_xy: grid points per axis; must be >= 2 (the GUI requires >= 5).
        N_theta: number of headings.

    Returns:
        Array of (x, y, theta) rows; shape (0, 3) if nothing matched.

    Raises:
        ValueError: if ``sensor_type`` is unknown.
    """
    # Sample the environment's own bounding box.
    extent = CIRCLE_RADIUS if env_type == "Disk" else SQUARE_HALF
    xs = np.linspace(-extent, extent, N_xy)
    ys = np.linspace(-extent, extent, N_xy)

    inside = in_env_circle if env_type == "Disk" else in_env_square
    # 0.7 grid spacings ~ half a grid-cell diagonal (0.707 spacings).
    tol = (2 * extent / (N_xy - 1)) * 0.7

    thetas = np.linspace(0, 2*math.pi, N_theta, endpoint=False)
    alpha = math.radians(alpha_deg)

    # Hoist loop invariants out of the triple loop: the sensor dispatch, and
    # environment membership (which does not depend on theta).
    sensor = _select_sensor(env_type, sensor_type, alpha)
    grid = [(x, y) for x in xs for y in ys if inside(x, y)]

    pts = []
    for th in thetas:
        for x, y in grid:
            h = sensor(x, y, th)
            if h is not None and abs(h - target_h) <= tol:
                pts.append((x, y, th))

    # Inject analytical samples the grid can miss.
    pts.extend(
        add_origin_samples_if_needed(
            env_type, sensor_type, target_h, N_theta, tol
        )
    )

    if env_type == "Disk":
        pts.extend(
            add_circle_boundary_samples(
                env_type,
                sensor_type,
                target_h,
                alpha_deg,
                N_theta,
                tol
            )
        )

    return np.array(pts) if pts else np.empty((0, 3))

# ------------------------------
# Environment boundary plotting
# ------------------------------

def plot_disk_boundary(ax):
    """Draw the disk boundary (radius CIRCLE_RADIUS) in red in the theta = 0 plane."""
    t = np.linspace(0, 2*math.pi, 200)
    # Scale by the environment constant rather than assuming a unit circle.
    ax.plot(CIRCLE_RADIUS * np.cos(t), CIRCLE_RADIUS * np.sin(t),
            np.zeros_like(t), color="red")

def plot_square_boundary(ax):
    """Draw the square boundary ([-SQUARE_HALF, SQUARE_HALF]^2) in red at theta = 0."""
    s = SQUARE_HALF  # use the environment constant, not a hard-coded 1
    xs = [-s, s, s, -s, -s]
    ys = [-s, -s, s, s, -s]
    ax.plot(xs, ys, [0]*5, color="red")

# ------------------------------
# Plotting
# ------------------------------

# Shared "Preimage" figure and its 3-D axes, created lazily by plot_preimage.
fig = ax = None

def plot_preimage(points, env_type, sensor_type, target_h):
    """Scatter-plot preimage samples in (x, y, theta) space.

    One module-level figure/axes pair ("Preimage") is reused across calls,
    and the axes are cleared each time.  If the user has closed that window,
    a fresh figure is created instead of drawing into the destroyed one.

    Args:
        points: array whose rows are (x, y, theta); may have zero rows.
        env_type: "Disk" draws the circle boundary, anything else the square.
        sensor_type: shown in the title.
        target_h: sensor value, shown in the title.
    """
    global fig, ax

    # A closed window leaves `fig` non-None but unregistered with pyplot;
    # drawing into it would silently show nothing, so recreate it.
    if fig is None or not plt.fignum_exists(fig.number):
        plt.ion()
        fig = plt.figure("Preimage", figsize=(7, 7))
        ax = fig.add_subplot(111, projection="3d")
    else:
        ax.cla()

    if len(points):
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=6)

    if env_type == "Disk":
        plot_disk_boundary(ax)
    else:
        plot_square_boundary(ax)

    extent = CIRCLE_RADIUS if env_type == "Disk" else SQUARE_HALF
    ax.set_xlim(-extent, extent)
    ax.set_ylim(-extent, extent)
    ax.set_zlim(0, 2*math.pi)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("theta")
    ax.set_title(f"{sensor_type}, h = {target_h:.3f}")

    fig.canvas.draw()

# ------------------------------
# GUI callbacks
# ------------------------------

def read_resolution():
    """Parse and validate the grid-resolution entry fields.

    Returns:
        (N_xy, N_theta) as ints.

    Raises:
        ValueError: if an entry is not an integer, or N_xy < 5 or N_theta < 4.
    """
    n_xy, n_theta = int(nxy_entry.get()), int(ntheta_entry.get())
    if not (n_xy >= 5 and n_theta >= 4):
        raise ValueError("Use N_xy >= 5 and N_theta >= 4")
    return n_xy, n_theta

def on_compute():
    """Callback for "Compute": plot the preimage for the single h in the form.

    Unparseable numbers or a rejected resolution are shown in an error
    dialog and nothing is computed.
    """
    try:
        target = float(h_entry.get())
        aperture = float(alpha_entry.get())
        N_xy, N_theta = read_resolution()
    except Exception as err:
        messagebox.showerror("Input Error", str(err))
        return

    env_type, sensor_type = env_var.get(), sensor_var.get()
    points = compute_preimage(env_type, sensor_type, target, aperture, N_xy, N_theta)
    plot_preimage(points, env_type, sensor_type, target)

_animation_job = None  # Tk `after` id of the pending animation frame, if any


def play_animation():
    """Callback for "Play Animation": sweep h linearly from h_min to h_max.

    All inputs are read once when the button is pressed: the h range, step
    count, alpha, resolution, and the environment/sensor choice.  Each frame
    is computed and drawn, and the next is scheduled 400 ms later with
    ``root.after``.  Pressing the button while an animation is running
    cancels its pending frame first, so only one animation runs at a time.
    Invalid input is reported in an error dialog.
    """
    global _animation_job

    try:
        hmin = float(hmin_entry.get())
        hmax = float(hmax_entry.get())
        steps = int(steps_entry.get())
        if steps < 1:
            # np.linspace raises for negative counts; reject here so the
            # user gets a dialog instead of an uncaught callback error.
            raise ValueError("Use # of steps >= 1")
        alpha = float(alpha_entry.get())
        N_xy, N_theta = read_resolution()
    except Exception as e:
        messagebox.showerror("Input Error", str(e))
        return

    # Stop a previous animation so two `after` chains don't interleave.
    if _animation_job is not None:
        root.after_cancel(_animation_job)
        _animation_job = None

    env_type = env_var.get()
    sensor_type = sensor_var.get()
    hs = np.linspace(hmin, hmax, steps)

    def animate(i=0):
        global _animation_job
        _animation_job = None
        if i >= len(hs):
            return
        pts = compute_preimage(env_type, sensor_type, hs[i], alpha, N_xy, N_theta)
        plot_preimage(pts, env_type, sensor_type, hs[i])
        _animation_job = root.after(400, animate, i + 1)

    animate()

# ------------------------------
# GUI
# ------------------------------

# Main window.  The callbacks above (on_compute, play_animation,
# read_resolution) read the widget globals created below at call time.
root = tk.Tk()
root.title("Robot Sensor Preimage Visualizer")

# Environment selector; the value "Disk" selects the disk geometry, and any
# other value is treated as the square by the geometry helpers.
frame_env = tk.Frame(root); frame_env.pack(anchor="w", padx=10)
tk.Label(frame_env, text="Environment =").pack(side=tk.LEFT)
env_var = StringVar(value="Disk")
OptionMenu(frame_env, env_var, "Disk", "Square").pack(side=tk.LEFT)

# Sensor selector; these strings must match the names compared in
# compute_preimage / add_circle_boundary_samples.
frame_sensor = tk.Frame(root); frame_sensor.pack(anchor="w", padx=10)
tk.Label(frame_sensor, text="Sensor =").pack(side=tk.LEFT)
sensor_var = StringVar(value="Distance to closest boundary")
OptionMenu(frame_sensor, sensor_var,
           "Distance to closest boundary",
           "Distance along heading",
           "Distance behind robot",
           "Corridor width along heading",
           "Two-ray aperture").pack(side=tk.LEFT)

# Target sensor value h for a single "Compute".
frame_h = tk.Frame(root); frame_h.pack(anchor="w", padx=10)
tk.Label(frame_h, text="h =").pack(side=tk.LEFT)
h_entry = Entry(frame_h, width=8); h_entry.insert(0, "0.2")
h_entry.pack(side=tk.LEFT)

# Aperture half-angle in degrees; only the "Two-ray aperture" sensor uses it
# (presumably the label's h5 refers to that fifth sensor in the menu).
frame_alpha = tk.Frame(root); frame_alpha.pack(anchor="w", padx=10)
tk.Label(frame_alpha, text="ℎ₅ aperture α (deg.) =").pack(side=tk.LEFT)
alpha_entry = Entry(frame_alpha, width=8); alpha_entry.insert(0, "30")
alpha_entry.pack(side=tk.LEFT)

# Sampling resolution: grid points per axis and number of headings
# (read_resolution enforces N_xy >= 5 and N_theta >= 4).
frame_res = tk.Frame(root); frame_res.pack(anchor="w", padx=10)
tk.Label(frame_res, text="N_xy =").pack(side=tk.LEFT)
nxy_entry = Entry(frame_res, width=6); nxy_entry.insert(0, "100")
nxy_entry.pack(side=tk.LEFT, padx=5)

tk.Label(frame_res, text="N_theta =").pack(side=tk.LEFT)
ntheta_entry = Entry(frame_res, width=6); ntheta_entry.insert(0, "30")
ntheta_entry.pack(side=tk.LEFT)

Button(root, text="Compute", command=on_compute).pack(pady=8)

# Animation controls: play_animation sweeps h over
# np.linspace(h_min, h_max, # of steps).
tk.Label(root, text="Animation", font=("Arial", 10, "bold")).pack(anchor="w", padx=10)

frame_hmin = tk.Frame(root); frame_hmin.pack(anchor="w", padx=10)
tk.Label(frame_hmin, text="h_min =").pack(side=tk.LEFT)
hmin_entry = Entry(frame_hmin, width=8); hmin_entry.insert(0, "0.0")
hmin_entry.pack(side=tk.LEFT)

frame_hmax = tk.Frame(root); frame_hmax.pack(anchor="w", padx=10)
tk.Label(frame_hmax, text="h_max =").pack(side=tk.LEFT)
hmax_entry = Entry(frame_hmax, width=8); hmax_entry.insert(0, "1.0")
hmax_entry.pack(side=tk.LEFT)

frame_steps = tk.Frame(root); frame_steps.pack(anchor="w", padx=10)
tk.Label(frame_steps, text="# of steps =").pack(side=tk.LEFT)
steps_entry = Entry(frame_steps, width=8); steps_entry.insert(0, "10")
steps_entry.pack(side=tk.LEFT)

# "Quit" destroys the main window.
frame_anim_btn = tk.Frame(root); frame_anim_btn.pack(pady=8)
Button(frame_anim_btn, text="Play Animation", command=play_animation).pack(side=tk.LEFT, padx=5)
Button(frame_anim_btn, text="Quit", fg="red", command=root.destroy).pack(side=tk.LEFT)

# Run the Tk event loop (blocks until the main window is destroyed).
root.mainloop()
