import pygame, math
from tkinter import *
from tkinter import messagebox  # <-- needed for errors

# --- constants ---
WINSIZE = [1000, 1000]  # window size in pixels: [width, height]

# RGB colours
white = (255, 255, 255)
black = (0, 0, 0)
blue = (50, 50, 255)
red = (255, 0, 0)
green = (0, 255, 0)
grid_gray = (50, 50, 50)
axis_color = (220, 220, 220)
func_color = (255, 0, 0)
preimage_color = green
domain_band_color = (60, 60, 60, 60)  # RGBA for shaded domain band

GRID_SPACING = 50  # pixels between neighbouring grid lines
CENTER = (WINSIZE[0] // 2, WINSIZE[1] // 2)  # pixel position of the math origin

# Scale factors: pixels per math unit
X_SCALE = 50
Y_SCALE = 50

# View (math) bounds shown on screen: the x values of pixel column 0 and of
# the right window edge. Derived from CENTER so they always agree with the
# pixel <-> math mapping x = (px - CENTER[0]) / X_SCALE, including for odd
# window widths (where "-WINSIZE[0]//2" would floor towards -infinity).
VIEW_XMIN = -CENTER[0] / X_SCALE
VIEW_XMAX = (WINSIZE[0] - CENTER[0]) / X_SCALE

# --- Tk Exit ---
def Exit():
    """Close the Tk control window, shut pygame down and end the program.

    Raises:
        SystemExit: always, after both toolkits have been torn down.
    """
    master.destroy()
    pygame.quit()
    # Raise SystemExit directly instead of calling quit(): quit() is a helper
    # injected by the `site` module for interactive sessions and is not
    # guaranteed to exist (e.g. with `python -S` or in frozen executables).
    raise SystemExit

# --- grid + axes drawing ---
def draw_grid_and_axes(surface):
    """Clear `surface` to black, then paint the grid lines and both axes.

    Grid lines are spaced GRID_SPACING pixels apart and anchored at CENTER,
    so the axes always coincide with a grid line.
    """
    width, height = WINSIZE
    cx, cy = CENTER
    surface.fill(black)

    # Vertical grid lines: from the centre leftwards, then rightwards.
    columns_left = range(cx, -1, -GRID_SPACING)
    columns_right = range(cx + GRID_SPACING, width + 1, GRID_SPACING)
    for gx in (*columns_left, *columns_right):
        pygame.draw.line(surface, grid_gray, (gx, 0), (gx, height), 1)

    # Horizontal grid lines: from the centre upwards, then downwards.
    rows_up = range(cy, -1, -GRID_SPACING)
    rows_down = range(cy + GRID_SPACING, height + 1, GRID_SPACING)
    for gy in (*rows_up, *rows_down):
        pygame.draw.line(surface, grid_gray, (0, gy), (width, gy), 1)

    # Axes are drawn last (and thicker) so they sit on top of the grid.
    pygame.draw.line(surface, axis_color, (cx, 0), (cx, height), 3)  # Y-axis
    pygame.draw.line(surface, axis_color, (0, cy), (width, cy), 3)  # X-axis

def draw_domain_band(surface, x_min, x_max):
    """Draw a translucent vertical band for the active domain X=[x_min,x_max].

    Args:
        surface: pygame surface to draw on.
        x_min, x_max: domain bounds in math units; they may be given in
            either order and are clamped to the visible view
            [VIEW_XMIN, VIEW_XMAX].

    Nothing is drawn when the clamped domain is narrower than one pixel.
    """
    # Clamp to the visible view (rather than a hard-coded +/-10) so the band
    # stays correct if WINSIZE or X_SCALE are changed.
    x_min = max(VIEW_XMIN, min(VIEW_XMAX, x_min))
    x_max = max(VIEW_XMIN, min(VIEW_XMAX, x_max))
    if x_min > x_max:
        x_min, x_max = x_max, x_min

    # px_right is an exclusive edge (band width = px_right - px_left), so it
    # may reach WINSIZE[0]; clamping it to WINSIZE[0] - 1 would leave the
    # last pixel column unshaded for a domain that touches the right edge.
    px_left = max(0, min(WINSIZE[0], int(CENTER[0] + x_min * X_SCALE)))
    px_right = max(0, min(WINSIZE[0], int(CENTER[0] + x_max * X_SCALE)))
    if px_left > px_right:
        px_left, px_right = px_right, px_left

    if px_right <= px_left:
        return

    # A per-pixel-alpha surface is needed for the RGBA fill to blend with
    # what is already on `surface` instead of overwriting it.
    band = pygame.Surface((px_right - px_left, WINSIZE[1]), pygame.SRCALPHA)
    band.fill(domain_band_color)
    surface.blit(band, (px_left, 0))

# --- function drawing ---
def draw_function(surface, func):
    """Plot `func` as a polyline, sampling it once per pixel column.

    Args:
        surface: pygame surface to draw on.
        func: callable taking one float (x in math units) and returning y.

    Columns where `func` raises, or where the result cannot be mapped to a
    finite pixel row (inf, nan, complex, ...), are treated as gaps: the
    polyline is interrupted there and resumes at the next valid sample.
    """
    prev_point = None
    for px in range(WINSIZE[0]):
        x = (px - CENTER[0]) / X_SCALE
        try:
            py = CENTER[1] - func(x) * Y_SCALE  # screen y grows downwards
            # inf/nan must not reach pygame as coordinates; math.isfinite
            # also raises TypeError for complex values, which the handler
            # below turns into a gap as well.
            if not math.isfinite(py):
                prev_point = None
                continue
            pt = (px, py)
            if prev_point is not None:
                pygame.draw.line(surface, func_color, prev_point, pt, 2)
            prev_point = pt
        except Exception:
            # Deliberately broad: `func` wraps a user-typed expression and can
            # fail in arbitrary ways (domain error, division by zero, ...).
            prev_point = None

# --- preimage plotting: plot green points on X-axis where f(x) = Y ---
def _find_preimage_roots(func, Y_value, x_min, x_max, tol):
    """Return the sorted solutions of func(x) = Y_value found on [x_min, x_max].

    The domain is sampled at both endpoints and at every pixel column that
    lies strictly between them. A solution is recorded when
      * |f(x) - Y| < tol at a sample,
      * f(x) - Y changes sign between two consecutive valid samples
        (located by linear interpolation), or
      * the sample is nearly a flat/tangential zero (small value and small
        numeric derivative) and the parabola vertex through its neighbours
        satisfies |f(x) - Y| < tol.

    Samples where `func` raises or does not return a finite real number are
    treated as undefined and never bracket a root.

    NOTE: a sign change is reported as a root without checking continuity,
    so a pole that flips sign between two samples is reported as well.
    """
    h = 1.0 / X_SCALE  # one pixel, in math units
    flat_y_tol = max(tol * 5, 1e-4)
    deriv_tol = 1e-3
    min_sep = 0.01  # roots closer than this are considered the same root
    roots = []

    def g(x):
        # f(x) - Y as a float, or None where it is undefined. The except is
        # deliberately broad because `func` wraps a user-typed expression;
        # math.isfinite also rejects complex results via TypeError.
        try:
            value = func(x) - Y_value
            return float(value) if math.isfinite(value) else None
        except Exception:
            return None

    def add_root(x):
        # Keep only in-domain roots, de-duplicated against *all* roots found
        # so far (endpoint roots are not necessarily the last one appended).
        if x_min <= x <= x_max and all(abs(x - r) > min_sep for r in roots):
            roots.append(x)

    def deriv(x):
        # Central-difference derivative of g, or None if a neighbour is undefined.
        ahead, behind = g(x + h), g(x - h)
        if ahead is None or behind is None:
            return None
        return (ahead - behind) / (2.0 * h)

    def quad_refine_extremum(x):
        # Vertex of the parabola through (x-h, x, x+h); falls back to x.
        g1, g2, g3 = g(x - h), g(x), g(x + h)
        if g1 is None or g2 is None or g3 is None:
            return x
        denom = g1 - 2.0 * g2 + g3
        if denom == 0:
            return x
        return x + 0.5 * h * (g1 - g3) / denom

    # Sample points: the exact endpoints plus every pixel column strictly
    # inside the domain, so no sample lies outside [x_min, x_max] and the
    # stretch between the last pixel and x_max is bracketed too.
    samples = [x_min]
    first_px = math.floor(CENTER[0] + x_min * X_SCALE)
    last_px = math.ceil(CENTER[0] + x_max * X_SCALE)
    for px in range(first_px, last_px + 1):
        x = (px - CENTER[0]) / X_SCALE
        if x_min < x < x_max:
            samples.append(x)
    if x_max > x_min:
        samples.append(x_max)

    x_prev, g_prev = None, None
    for x in samples:
        gv = g(x)
        if gv is None:
            # Undefined sample: forget the previous one so that a bracket
            # never spans a gap in the function's domain.
            g_prev = None
            continue

        if abs(gv) < tol:
            # near-exact root at current sample
            add_root(x)
        elif g_prev is not None and g_prev * gv < 0:
            # sign change -> root between x_prev and x
            t = abs(g_prev) / (abs(g_prev) + abs(gv))
            add_root(x_prev + t * (x - x_prev))
        elif abs(gv) < flat_y_tol:
            # flat-touching root (g ~ 0, g' ~ 0) -> refine with parabola
            d = deriv(x)
            if d is not None and abs(d) < deriv_tol:
                x_star = quad_refine_extremum(x)
                g_star = g(x_star)
                if g_star is not None and abs(g_star) < tol:
                    add_root(x_star)

        # x_prev and g_prev always advance together, keeping the bracket
        # used for interpolation consistent.
        x_prev, g_prev = x, gv

    return sorted(roots)


def draw_preimage_points(surface, func, Y_value, x_min=-10, x_max=10, tol=1e-3):
    """
    Draw green points for solutions of f(x)=Y_value, only for x in [x_min, x_max].

    Also draws the horizontal reference line y = Y_value across the window.

    Args:
        surface: pygame surface to draw on.
        func: callable taking one float x and returning f(x).
        Y_value: target value y whose preimage is wanted.
        x_min, x_max: domain bounds in math units; given in either order,
            clamped to the visible view [VIEW_XMIN, VIEW_XMAX].
        tol: |f(x) - Y_value| below which a sample counts as a solution.

    Returns:
        Sorted list of the distinct solutions found inside the clamped
        domain (floats, math units). Every returned root is in the domain.
    """
    # clamp/sanitize domain
    x_min = float(max(VIEW_XMIN, min(VIEW_XMAX, x_min)))
    x_max = float(max(VIEW_XMIN, min(VIEW_XMAX, x_max)))
    if x_min > x_max:
        x_min, x_max = x_max, x_min

    # draw the horizontal reference line y = Y_value
    py_line = CENTER[1] - Y_value * Y_SCALE
    pygame.draw.line(surface, preimage_color, (0, py_line), (WINSIZE[0], py_line), 1)

    roots = _find_preimage_roots(func, Y_value, x_min, x_max, tol)

    # draw roots as small green circles on X-axis (y=0 => screen y = CENTER[1])
    for xr in roots:
        px = int(CENTER[0] + xr * X_SCALE)
        # "<=" so that a root exactly on the right view edge still gets a
        # (half-visible) marker instead of being counted but not shown.
        if 0 <= px <= WINSIZE[0]:
            pygame.draw.circle(surface, preimage_color, (px, CENTER[1]), 4)

    return roots

def parse_domain_fields():
    """Read x1, x2 from UI; return clamped, ordered (x1, x2).

    Returns:
        A pair of floats (x1, x2) with x1 <= x2, both clamped to the visible
        view [VIEW_XMIN, VIEW_XMAX]. If either field cannot be read as a
        number (or is NaN) the whole visible view is returned.
    """
    default_domain = (float(VIEW_XMIN), float(VIEW_XMAX))
    try:
        x1 = float(entryX1.get().strip())
        x2 = float(entryX2.get().strip())
    except Exception:
        # Deliberately broad best-effort: any failure to read the fields
        # (not only a malformed number) falls back to the default domain.
        return default_domain
    if math.isnan(x1) or math.isnan(x2):
        # float() accepts "nan", which cannot be ordered or clamped sensibly.
        return default_domain
    # Clamp against the float view bounds so both values are always floats,
    # matching the default path (a literal -10/10 clamp returned ints).
    x1 = float(max(VIEW_XMIN, min(VIEW_XMAX, x1)))
    x2 = float(max(VIEW_XMIN, min(VIEW_XMAX, x2)))
    if x1 > x2:
        x1, x2 = x2, x1
    return x1, x2

def replot():
    """Tk callback: redraw grid, domain band and the curve typed in the f(x) field.

    The outcome is reported on the status line; on failure an error dialog
    is shown as well.
    """
    screen.fill(black)
    expression = entry.get().strip()
    if expression == "":
        status_var.set("Enter an expression, e.g. sin(x)")
        return

    try:
        curve = make_func(expression)
        draw_grid_and_axes(screen)
        # shade the active domain (falls back to the default domain when the
        # fields cannot be read)
        draw_domain_band(screen, *parse_domain_fields())
        draw_function(screen, curve)
        pygame.display.flip()
        status_var.set(f"Plotted: f(x) = {expression}")
    except Exception as err:
        status_var.set("Error")
        messagebox.showerror("Error", f"Could not parse/evaluate expression:\n{err}")

def compute_preimage():
    """Tk callback: solve f(x) = y over the chosen domain and mark the solutions.

    Reads the expression, the target y and the domain from the UI, replots
    the curve, overlays the preimage points and reports how many solutions
    were found on the status line.
    """
    expression = entry.get().strip()
    if expression == "":
        status_var.set("Enter an expression, e.g. sin(x)")
        return

    try:
        Y_value = float(entryY.get().strip())
    except Exception:
        messagebox.showerror("Error", "y must be a number (e.g. 0, 1.5, -2)")
        return

    # domain X = [x1, x2], already clamped and ordered
    x1, x2 = parse_domain_fields()

    try:
        target = make_func(expression)
        # start from a clean plot so markers from earlier runs disappear
        replot()
        solutions = draw_preimage_points(screen, target, Y_value, x_min=x1, x_max=x2)
        pygame.display.flip()
        summary = (
            f"Found {len(solutions)} solution(s) for f(x) = {Y_value} on X=[{x1}, {x2}]"
            if solutions
            else f"No solutions for f(x) = {Y_value} on X=[{x1}, {x2}]"
        )
        status_var.set(summary)
    except Exception as err:
        status_var.set("Error")
        messagebox.showerror("Error", f"Preimage calculation failed:\n{err}")

# very limited eval environment: public math names + int/abs + x
_allowed = {name: getattr(math, name) for name in dir(math) if not name.startswith("_")}
_allowed.update({"int": int, "abs": abs})  # allow Python abs()

def make_func(expr):
    """Turn the expression text `expr` into a callable f(x).

    The expression is compiled once, up front, so that
      * a malformed expression fails here, where callers can report it,
        instead of failing silently on every sampled x, and
      * the text is not re-parsed for each of the many per-pixel calls.

    Args:
        expr: a Python expression in the variable x that may use the public
            names of the math module plus int() and abs().

    Returns:
        A function f(x) evaluating the expression. Evaluation errors
        (unknown name, division by zero, domain error, ...) are raised by
        f itself, at call time.

    Raises:
        SyntaxError: if `expr` is not a syntactically valid expression.
    """
    # eval() ignores leading spaces/tabs of a source string; strip them so
    # that compile() accepts exactly the same input.
    source = expr.lstrip(" \t") if isinstance(expr, str) else expr
    code = compile(source, "<expression>", "eval")

    def f(x):
        # SECURITY NOTE: emptying __builtins__ only limits the names that
        # are directly reachable; it is not a sandbox. Do not feed this
        # function expressions from an untrusted source.
        return eval(code, {"__builtins__": {}}, {**_allowed, "x": x})
    return f

# --- init pygame ---
pygame.init()
screen = pygame.display.set_mode(WINSIZE)
pygame.display.set_caption('Preimage Visualization')

# initial draw: grid plus sin(x), matching the default text of the f(x) field
draw_grid_and_axes(screen)
draw_function(screen, math.sin)
pygame.display.flip()

# --- Tk UI ---
master = Tk()
master.title('Controls')
master.geometry("720x190")

# Row 1: function input
row1 = Frame(master)
row1.pack(fill=X, padx=10, pady=(8, 4))
Label(row1, text="f(x) =").pack(side=LEFT)
entry = Entry(row1, width=40)
entry.insert(0, "sin(x)")
entry.pack(side=LEFT, padx=6)
plot_btn = Button(row1, text="Plot", command=replot)
plot_btn.pack(side=LEFT, padx=6)

# Row 2: Y value + Preimage button
row2 = Frame(master)
row2.pack(fill=X, padx=10, pady=(2, 4))
Label(row2, text="Observation y =").pack(side=LEFT)
entryY = Entry(row2, width=12)
entryY.insert(0, "0")  # default target value
entryY.pack(side=LEFT, padx=6)
pre_btn = Button(row2, text="Compute preimage (f(x)=y) over the domain X", command=compute_preimage)
pre_btn.pack(side=LEFT, padx=6)

# Row 3: Domain X = [x1, x2]
row3 = Frame(master)
row3.pack(fill=X, padx=10, pady=(2, 8))
Label(row3, text="Domain X=[").pack(side=LEFT)
entryX1 = Entry(row3, width=10)
entryX1.insert(0, "-1")  # default x1
entryX1.pack(side=LEFT, padx=(4, 2))
Label(row3, text=",").pack(side=LEFT)
entryX2 = Entry(row3, width=10)
entryX2.insert(0, "3")   # default x2
entryX2.pack(side=LEFT, padx=(2, 4))
Label(row3, text="]").pack(side=LEFT)

# Status line
status_var = StringVar(value="Ready. Try: sin(x), cos(x), x**2, exp(-x**2), tan(x)")
status = Label(master, textvariable=status_var, anchor="w")
status.pack(fill=X, padx=10, pady=(0, 6))

# Shortcuts: Return in the f(x) field replots, Return in any other field
# recomputes the preimage.
entry.bind("<Return>", lambda e: replot())
entryY.bind("<Return>", lambda e: compute_preimage())
entryX1.bind("<Return>", lambda e: compute_preimage())
entryX2.bind("<Return>", lambda e: compute_preimage())

# Quit
quit_btn = Button(master, text="Quit", command=Exit, fg="red")
quit_btn.pack(pady=(2, 8))

# --- keep the pygame window alive while Tk owns the main loop ---
# Tk's mainloop never touches pygame's event queue. If that queue is not
# serviced the OS may flag the pygame window as unresponsive, and its close
# button has no effect. Poll it from a recurring Tk timer instead.
_EVENT_POLL_MS = 50

def _pump_pygame_events():
    """Drain pygame's event queue; quit the program if its window was closed."""
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            Exit()
            return
    master.after(_EVENT_POLL_MS, _pump_pygame_events)

master.after(_EVENT_POLL_MS, _pump_pygame_events)

master.mainloop()
