import pygame, math
from tkinter import *
from tkinter import messagebox
import pygame.freetype

# =========================
# Window & drawing settings
# =========================
WINSIZE = (1100, 900)   # pygame window size in pixels: (width, height)
BG = (0, 0, 0)          # background fill color (RGB)
GRID = (60, 60, 60)     # grid/tick color
AXIS = (220, 220, 220)  # main axes (t=0, y=0)
ACC_COLOR = (255, 80, 80)     # acceleration curve and its panel label
VEL_COLOR = (60, 220, 120)    # velocity curve and its panel label
POS_COLOR = (80, 160, 255)    # position curve and its panel label
CURSOR_DOT = (255, 230, 90)   # moving dot, connector line and node in the position panel
TIME_CURSOR = (230, 230, 230) # vertical "current time" line drawn in every panel

# Margins (pixels) between the window edges and the stacked plot panels.
LEFT_PAD = 70
RIGHT_PAD = 30
TOP_PAD = 30
BOTTOM_PAD = 30
PANEL_GAP = 30          # vertical gap between two neighbouring panels

# -------------
# Restricted eval environment (limits visible names; NOT a security sandbox)
# -------------
# Names visible to user-typed expressions: every public attribute of the
# math module (functions and constants alike), plus the helpers below.
_allowed = {name: getattr(math, name) for name in dir(math) if name[:1] != "_"}
_allowed["pi"] = math.pi
_allowed["e"] = math.e


def Heaviside(u):
    """Unit step function: 1.0 when u >= 0, otherwise 0.0."""
    if u >= 0:
        return 1.0
    return 0.0


def clamp(a, lo, hi):
    """Limit a to the interval [lo, hi]."""
    capped = min(hi, a)
    return max(lo, capped)


_allowed["Heaviside"] = Heaviside
_allowed["clamp"] = clamp

# ===========  Nice ticks  ===========
def nice_ticks(a, b, target=6):
    """
    Return 'nice' tick values between a and b using the 1-2-5 rule.

    a, b:    range endpoints, in either order. When they are equal the range
             is widened by 0.5 on each side.
    target:  approximate upper bound on the number of tick intervals.

    Returns a list of multiples of the chosen step that lie inside the range
    (within a tolerance of 1e-9 steps). Returns an empty list when the range
    is not finite, since no tick can be placed on it.
    """
    if not (math.isfinite(a) and math.isfinite(b)):
        return []
    if a == b:
        span = 1.0
        lo = a - 0.5
        hi = b + 0.5
    else:
        lo, hi = (a, b) if a < b else (b, a)
        span = hi - lo
    if not math.isfinite(span):
        # hi - lo overflowed (endpoints near opposite float limits)
        return []

    raw = span / max(1, target)
    mag = 10 ** math.floor(math.log10(raw)) if raw > 0 else 1
    for m in (1, 2, 5, 10):
        step = m * mag
        if span / step <= target:
            break

    # Enumerate ticks by integer index instead of accumulating v += step:
    # the tolerance then scales with the step (tiny ranges no longer overshoot
    # hi), there is no accumulated rounding drift, and the loop always ends
    # even when step is smaller than the float resolution of lo.
    first = math.ceil(lo / step - 1e-9)
    last = math.floor(hi / step + 1e-9)
    return [k * step for k in range(first, last + 1)]

# ===========  Integration  ===========
def integrate_motion(accel_func, t0, t1, dt, x0, v0):
    """
    Euler integration: given ẍ(t), return arrays t, a, v, x on [t0, t1].

    accel_func: callable a(t). Any exception it raises is recorded as NaN.
    t0, t1:     integration interval; swapped if given in reverse order.
    dt:         step size, must be > 0.
    x0, v0:     position and velocity at the start of the interval.

    Returns four equal-length lists (T, A, V, X).

    Raises ValueError if dt <= 0 or if t0, t1 or dt is not finite.
    """
    if dt <= 0:
        raise ValueError("dt must be > 0")
    if not (math.isfinite(t0) and math.isfinite(t1) and math.isfinite(dt)):
        raise ValueError("t0, t1 and dt must be finite")
    if t1 < t0:
        t0, t1 = t1, t0

    # (t1 - t0)/dt is often a hair below an integer (0.3/0.1 -> 2.9999999999999996);
    # plain truncation would then silently drop the final sample at t1.
    steps = (t1 - t0) / dt
    nearest = round(steps)
    if math.isclose(steps, nearest, rel_tol=1e-9, abs_tol=1e-9):
        n = int(nearest) + 1
    else:
        n = int(steps) + 1

    def sample(t):
        # The expression comes from the user; a failure at one instant
        # (e.g. division by zero) must not abort the whole run.
        try:
            return accel_func(t)
        except Exception:
            return float("nan")

    T = [t0 + i * dt for i in range(n)]
    A = [0.0] * n
    V = [0.0] * n
    X = [0.0] * n
    V[0] = v0
    X[0] = x0
    A[0] = sample(T[0])

    for i in range(1, n):
        A[i] = sample(T[i])
        # Euler step using the previous sample. Non-finite previous values
        # are replaced by 0.0 in the stored arrays so the run can continue.
        prev = i - 1
        if not math.isfinite(A[prev]):
            A[prev] = 0.0
        if not math.isfinite(V[prev]):
            V[prev] = 0.0
        if not math.isfinite(X[prev]):
            X[prev] = 0.0
        V[i] = V[prev] + A[prev] * dt
        X[i] = X[prev] + V[prev] * dt
    return T, A, V, X

# -------------  Expression -> ẍ(t)  -------------
def make_accel(expr):
    """
    Build the acceleration function ẍ(t) from a user-typed expression.

    The expression is compiled once and then evaluated with the names in
    _allowed plus the variable t. Booleans act as 1/0, so piecewise forms
    such as (2<=t)*(t<3)*2 work.

    Raises:
        ValueError:  the expression is empty.
        SyntaxError: the expression is not a valid Python expression.
        Other exceptions (NameError, TypeError, ...) raised by a trial
        evaluation at t = 0 are propagated, except arithmetic domain errors
        (see below).

    Returns a callable a(t) -> float.
    """
    expr = expr.strip()
    if not expr:
        raise ValueError("Acceleration expression is empty.")

    # Compile once: syntax errors surface immediately and every sample reuses
    # the code object instead of re-parsing the text.
    code = compile(expr, "<acceleration>", "eval")

    # NOTE(security): emptying __builtins__ only restricts which names
    # resolve; eval() is not a sandbox. Do not feed it untrusted input.
    def a(t):
        value = eval(code, {"__builtins__": {}}, {**_allowed, "t": float(t)})
        if isinstance(value, (str, bytes)):
            # float("...") would otherwise fail with a misleading ValueError
            raise TypeError("Acceleration expression must evaluate to a number.")
        return float(value)

    # Quick sanity check at t = 0. Domain errors there (1/t, log(t), sqrt(t-1))
    # do not make the expression invalid on other intervals, so only those
    # are tolerated; unknown names and type errors still propagate.
    try:
        a(0.0)
    except (ZeroDivisionError, OverflowError, ValueError):
        pass
    return a

# ====================  Drawing helpers  ====================
def panel_rects():
    """
    Compute the three equally sized plot panels, stacked top to bottom.

    Returns a tuple (acc_rect, vel_rect, pos_rect) of pygame.Rect objects
    separated vertically by PANEL_GAP and inset by the *_PAD margins.
    """
    width = WINSIZE[0] - LEFT_PAD - RIGHT_PAD
    usable_height = WINSIZE[1] - TOP_PAD - BOTTOM_PAD - 2 * PANEL_GAP
    height = usable_height // 3
    pitch = height + PANEL_GAP  # distance between the tops of two neighbours
    return tuple(
        pygame.Rect(LEFT_PAD, TOP_PAD + row * pitch, width, height)
        for row in range(3)
    )

def draw_panel_border(surface, r):
    """Outline panel rect r on surface with a 1-pixel grey frame."""
    frame_color = (90, 90, 90)
    thickness = 1
    pygame.draw.rect(surface, frame_color, r, thickness)

def _format_tick(value):
    """Format a tick value compactly ('0.5', '10', '1e+06'), never as '-0'."""
    return f"{value + 0.0:g}"  # adding 0.0 turns -0.0 into 0.0


def map_series_to_panel(surface, r, T, Y, color, t0, t1, ymin=None, ymax=None, lw=2):
    """
    Draws:
      - 'nice' grid using data ticks on t and y, with tick labels
      - axes at t=0 and y=0 (inclusive if on boundary)
      - the series polyline (broken wherever a sample is not finite)

    surface:    target pygame surface
    r:          panel rectangle (pygame.Rect)
    T, Y:       sample times and values
    t0, t1:     time range mapped onto the panel width
    ymin, ymax: optional fixed value range; by default the finite data range
                is used. Either way a 10% margin is added on both sides.
    lw:         polyline width in pixels

    Returns (y_min_used, y_max_used, x_pixel_of_t_axis) so overlays can match
    the scaling. The x pixel is r.left when t=0 lies outside [t0, t1]. When Y
    has no finite value only the border is drawn and (None, None, r.left) is
    returned.
    """
    # auto y-range
    finite_vals = [v for v in Y if math.isfinite(v)]
    if not finite_vals:
        draw_panel_border(surface, r)
        return None, None, r.left
    y_min = min(finite_vals) if ymin is None else ymin
    y_max = max(finite_vals) if ymax is None else ymax
    if y_min == y_max:
        y_min -= 1
        y_max += 1
    pad = 0.10 * (y_max - y_min)
    y_min -= pad
    y_max += pad

    has_t_span = t1 != t0
    has_y_span = y_max != y_min
    span_t = (t1 - t0) if has_t_span else 1.0
    span_y = (y_max - y_min) if has_y_span else 1.0

    def to_xpix(t):
        return r.left + int((t - t0) / span_t * r.width)

    def to_ypix(y):
        return r.bottom - int((y - y_min) / span_y * r.height)

    # --------- grid with nice ticks ---------
    # Labels are formatted with %g rather than int(): ticks such as
    # 0.2, 0.4, 0.6 would otherwise all be printed as "0".
    # vertical grid (time); labels are centred under their grid line
    for tt in nice_ticks(t0, t1, target=7):
        xpix = to_xpix(tt) if has_t_span else r.left
        pygame.draw.line(surface, GRID, (xpix, r.top), (xpix, r.bottom), 1)
        label = _format_tick(tt)
        label_box = TICK_FONT.get_rect(label)
        put_tick_text(surface, xpix - label_box.width // 2, r.bottom + 4, label)

    # horizontal grid (value); labels are right-aligned left of the panel
    for yy in nice_ticks(y_min, y_max, target=6):
        ypix = to_ypix(yy) if has_y_span else r.bottom
        pygame.draw.line(surface, GRID, (r.left, ypix), (r.right, ypix), 1)
        label = _format_tick(yy)
        label_box = TICK_FONT.get_rect(label)
        put_tick_text(surface, r.left - 6 - label_box.width,
                      ypix - label_box.height // 2, label)

    # panel border
    draw_panel_border(surface, r)

    # --------- axes at 0 (inclusive) ----------
    if has_y_span and y_min <= 0 <= y_max:
        y0 = to_ypix(0)
        pygame.draw.line(surface, AXIS, (r.left, y0), (r.right, y0), 2)
    axis_xpix = r.left
    if has_t_span and t0 <= 0 <= t1:
        axis_xpix = to_xpix(0)
        pygame.draw.line(surface, AXIS, (axis_xpix, r.top), (axis_xpix, r.bottom), 2)

    # --------- series polyline ----------
    prev = None
    for t, y in zip(T, Y):
        if not (math.isfinite(t) and math.isfinite(y)):
            # int() cannot convert NaN/inf to a pixel: leave a gap in the curve
            prev = None
            continue
        point = (to_xpix(t), to_ypix(y))
        if prev is not None:
            pygame.draw.line(surface, color, prev, point, lw)
        prev = point

    return (y_min, y_max, axis_xpix)

def put_tick_text(surface, x, y, text):
    """Render a white tick label with TICK_FONT; (x, y) is its top-left corner."""
    white = (255, 255, 255)
    top_left = (x, y)
    TICK_FONT.render_to(surface, top_left, text, white)

def put_label(surface, r, text, color):
    """Render a panel title 6 px inside the top-left corner of rect r.

    Uses LABEL_FONT (freetype) so Unicode such as ẋ and ẍ can be rendered.
    """
    inset = 6
    anchor = (r.left + inset, r.top + inset)
    LABEL_FONT.render_to(surface, anchor, text, color)

def put_status(surface, text):
    """Render the status line near the bottom-left of the window with STATUS_FONT."""
    status_color = (230, 230, 230)
    anchor = (LEFT_PAD, WINSIZE[1] - 28)
    STATUS_FONT.render_to(surface, anchor, text, status_color)

def draw_time_cursor(surface, r, t0, t1, t_cur, color=TIME_CURSOR):
    """
    Mark the current time t_cur with a 1-pixel vertical line across panel r.

    The line is kept inside the panel horizontally. Returns the x pixel that
    was used; for an empty time range nothing is drawn and r.left is returned.
    """
    if t1 == t0:
        return r.left
    fraction = (t_cur - t0) / (t1 - t0)
    unclamped = r.left + int(fraction * r.width)
    xpix = max(r.left, min(r.right, unclamped))
    top_point = (xpix, r.top)
    bottom_point = (xpix, r.bottom)
    pygame.draw.line(surface, color, top_point, bottom_point, 1)
    return xpix

def pos_value_to_ypix(pos_rect, y_val, y_min, y_max):
    """
    Map a position value to a y pixel in the pos panel.

    pos_rect:     panel rectangle (needs top, bottom, height, centery)
    y_val:        value to map
    y_min, y_max: value range shown by the panel; None means "no scale"

    Returns a pixel row clamped to [pos_rect.top, pos_rect.bottom]. Falls back
    to pos_rect.centery when there is no scale or the value is NaN; +inf/-inf
    map to the top/bottom edge instead of raising in int().
    """
    if y_min is None or y_max is None:
        return pos_rect.centery
    span_y = (y_max - y_min) if y_max != y_min else 1.0
    fraction = (y_val - y_min) / span_y
    if math.isnan(fraction):
        return pos_rect.centery
    # Clamp before int(): int(inf) raises OverflowError.
    fraction = max(0.0, min(1.0, fraction))
    ypix = pos_rect.bottom - int(fraction * pos_rect.height)
    return max(pos_rect.top, min(pos_rect.bottom, ypix))

def draw_position_dot_on_axis(surface, pos_rect, X, idx, y_min, y_max, axis_xpix):
    """
    Draw the moving dot on the vertical axis of the position panel.

    The dot sits at x pixel axis_xpix, at the height matching X[idx] under the
    (y_min, y_max) scaling. Nothing is drawn when the scaling is unknown or X
    is empty.
    """
    has_scale = y_min is not None and y_max is not None
    if not has_scale or not X:
        return
    center = (axis_xpix, pos_value_to_ypix(pos_rect, X[idx], y_min, y_max))
    pygame.draw.circle(surface, CURSOR_DOT, center, 8)

# =========  Playback state  =========
# Result of the most recent successful integration (empty until the first plot).
# Each name gets its own list: a chained assignment would alias all four to a
# single object, so an in-place change to one would show up in the others.
last_T, last_A, last_V, last_X = [], [], [], []
last_t0 = 0.0       # time range of the stored result
last_t1 = 10.0
play_idx = 0        # index into last_T of the sample under the time cursor
is_playing = False  # True while the cursor advances automatically

def redraw_from_last(cursor_idx=None):
    """
    Redraw all three panels from the stored result and overlay the cursor.

    cursor_idx: index into last_T of the sample to highlight; None means the
                first sample. Out-of-range values are clamped.

    Draws the series, panel labels, a time cursor across all panels, the
    moving dot on the position panel's vertical axis with a connector to the
    cursor, and the status line. With no stored result only the empty panel
    frames are drawn.
    """
    screen.fill(BG)
    acc_rect, vel_rect, pos_rect = panel_rects()

    if not last_T:
        for r in (acc_rect, vel_rect, pos_rect):
            draw_panel_border(screen, r)
        pygame.display.flip()
        return

    T, A, V, X = last_T, last_A, last_V, last_X
    t0, t1 = last_t0, last_t1
    idx = 0 if cursor_idx is None else int(max(0, min(len(T) - 1, cursor_idx)))
    t_cur = T[idx]

    # draw series and capture pos panel scaling/axis
    map_series_to_panel(screen, acc_rect, T, A, ACC_COLOR, t0, t1)
    map_series_to_panel(screen, vel_rect, T, V, VEL_COLOR, t0, t1)
    y_min, y_max, axis_xpix = map_series_to_panel(screen, pos_rect, T, X, POS_COLOR, t0, t1)

    # labels
    ACC_LABEL = " x\u0308(t)  (acceleration)"   # ẍ(t)
    VEL_LABEL = "\u1E8B(t)  (velocity)"        # ẋ(t)
    POS_LABEL = "x(t)  (position)"
    put_label(screen, acc_rect, ACC_LABEL, ACC_COLOR)
    put_label(screen, vel_rect, VEL_LABEL, VEL_COLOR)
    put_label(screen, pos_rect, POS_LABEL, POS_COLOR)

    # time cursor across all three panels; only the position panel's x pixel
    # is needed afterwards
    draw_time_cursor(screen, acc_rect, t0, t1, t_cur)
    draw_time_cursor(screen, vel_rect, t0, t1, t_cur)
    cx_pos = draw_time_cursor(screen, pos_rect, t0, t1, t_cur)

    # The dot/connector overlay needs a valid scaling and a finite sample;
    # otherwise the whole overlay is skipped rather than half drawn.
    if y_min is not None and y_max is not None and math.isfinite(X[idx]):
        ypix = pos_value_to_ypix(pos_rect, X[idx], y_min, y_max)
        # moving dot on the t=0 vertical axis in position panel
        draw_position_dot_on_axis(screen, pos_rect, X, idx, y_min, y_max, axis_xpix)
        # connect dot (axis_xpix, ypix) to cursor (cx_pos, ypix) on x(t)
        pygame.draw.line(screen, CURSOR_DOT, (axis_xpix, ypix), (cx_pos, ypix), 2)
        pygame.draw.circle(screen, CURSOR_DOT, (cx_pos, ypix), 5)  # node at intersection

    # status
    symbol = "Playing:" if is_playing else "Paused:"
    status_text = f"{symbol}  t={t_cur:.2f}   x={X[idx]:.6g}   \u1E8B={V[idx]:.6g}   x\u0308={A[idx]:.6g}"
    put_status(screen, status_text)

    pygame.display.flip()

# =========  Tk events  =========
def Exit():
    """Close the Tk window, shut pygame down and end the process via SystemExit."""
    # Best effort: any failure while tearing the Tk window down is ignored so
    # that pygame is still shut down and the process still exits.
    try:
        master.destroy()
    except Exception:
        pass
    pygame.quit()
    raise SystemExit()

def _show_plot_error(message):
    """Clear the plot window, draw the empty panel frames and show message as status."""
    screen.fill(BG)
    for r in panel_rects():
        draw_panel_border(screen, r)
    put_status(screen, message)
    pygame.display.flip()


def replot():
    """
    Read the control entries, integrate ẍ(t) and redraw the plots.

    Playback is stopped first. On invalid input the plot window shows an
    error message and the previously stored result is left untouched; on
    success the new result replaces it and the cursor returns to the first
    sample.
    """
    global last_T, last_A, last_V, last_X, last_t0, last_t1, play_idx, is_playing

    # Stop playback up front: the periodic redraw would otherwise overwrite
    # an error message on the very next frame.
    is_playing = False
    play_btn.config(text="Play")

    expr = entry_acc.get().strip()
    try:
        a = make_accel(expr)
    except Exception as e:  # user-typed expression: report any failure
        _show_plot_error(f"Error in acceleration expression: {e}")
        return

    try:
        t0 = float(entry_t0.get().strip())
        t1 = float(entry_t1.get().strip())
        dt = float(entry_dt.get().strip())
        x0 = float(entry_x0.get().strip())
        v0 = float(entry_v0.get().strip())
    except ValueError:
        _show_plot_error("Error: t0, t1, dt, x0, ẋ0 must be numbers.")
        return

    # integrate_motion swaps a reversed interval internally; normalise here as
    # well so the stored range matches the data (otherwise the plot is drawn
    # mirrored and the t=0 axis is never found).
    if t1 < t0:
        t0, t1 = t1, t0

    try:
        T, A, V, X = integrate_motion(a, t0, t1, dt, x0, v0)
    except Exception as e:
        _show_plot_error(f"Integration error: {e}")
        return

    # store and reset playback
    last_T, last_A, last_V, last_X = T, A, V, X
    last_t0, last_t1 = t0, t1
    play_idx = 0

    redraw_from_last(cursor_idx=play_idx)

def toggle_play():
    """
    Switch between playing and paused, update the button text and redraw.

    Starting playback while the cursor already sits on the last sample
    restarts from the first one; without this, Play would be stopped again
    by the very next frame and appear to do nothing.
    """
    global is_playing, play_idx
    is_playing = not is_playing
    if is_playing and last_T and play_idx >= len(last_T) - 1:
        play_idx = 0
    play_btn.config(text=("Pause" if is_playing else "Play"))
    # immediate redraw to update status icon/text
    redraw_from_last(cursor_idx=play_idx)

def reset_play():
    """Stop playback, move the cursor back to the first sample and redraw."""
    global play_idx, is_playing
    play_idx, is_playing = 0, False
    play_btn.config(text="Play")
    redraw_from_last(cursor_idx=0)

# ======  Pygame  ======
# Initialise pygame and its font modules before any window or font is created.
pygame.init()
pygame.freetype.init()
pygame.font.init()
# Plot window; 'screen' is the surface every drawing helper renders to.
screen = pygame.display.set_mode(WINSIZE)
pygame.display.set_caption("Integration Visualizer (ẍ(t) -> ẋ(t) -> x(t))")
screen.fill(BG)
pygame.display.flip()

# Global freetype fonts used by put_label, put_status and put_tick_text.
LABEL_FONT = pygame.freetype.SysFont("FreeSerif", 22) 
STATUS_FONT = pygame.freetype.SysFont("FreeSerif", 20)
TICK_FONT = pygame.freetype.SysFont("FreeSerif", 16) 

# =========  Tk window  =========
# Control window: a separate Tk top-level that sits beside the pygame plot window.
master = Tk()
master.title('Controls: ẍ(t) integration')
master.geometry("760x260")  # ensure buttons are visible
master.minsize(720, 240)

# Row 1: ẍ(t) expression entry and the "Compute & Plot" button
row1 = Frame(master)
row1.pack(fill=X, padx=10, pady=(8, 4))
Label(row1, text="Acceleration ẍ(t) =").pack(side=LEFT)
entry_acc = Entry(row1, width=48)
# Default: piecewise acceleration, +2 on [2, 3) and -3 on [3, 4), 0 elsewhere
entry_acc.insert(0, "(2<=t)*(t<3)*2 + (3<=t)*(t<4)*(-3)")
entry_acc.pack(side=LEFT, padx=6)
btn_plot = Button(row1, text="Compute & Plot", command=replot)
btn_plot.pack(side=LEFT, padx=6)

# Row 2: initial conditions (position and velocity at the start of the domain)
row2 = Frame(master)
row2.pack(fill=X, padx=10, pady=(2, 4))
Label(row2, text="Initial Conditions: x(t₀)=").pack(side=LEFT)
entry_x0 = Entry(row2, width=10)
entry_x0.insert(0, "0")
entry_x0.pack(side=LEFT, padx=(2, 10))
Label(row2, text="ẋ(t₀)=").pack(side=LEFT)
entry_v0 = Entry(row2, width=10)
entry_v0.insert(0, "1")
entry_v0.pack(side=LEFT, padx=(2, 10))

# Row 3: time window & dt
row3 = Frame(master)
row3.pack(fill=X, padx=10, pady=(2, 4))
Label(row3, text="Domain: t₀=").pack(side=LEFT)
entry_t0 = Entry(row3, width=8)
entry_t0.insert(0, "0")
entry_t0.pack(side=LEFT, padx=(2, 10))
Label(row3, text="t₁=").pack(side=LEFT)
entry_t1 = Entry(row3, width=8)
entry_t1.insert(0, "10")
entry_t1.pack(side=LEFT, padx=(2, 10))
Label(row3, text="Integration Step: dt=").pack(side=LEFT)
entry_dt = Entry(row3, width=8)
entry_dt.insert(0, "0.01")
entry_dt.pack(side=LEFT, padx=(2, 10))

# Row 4: playback controls (play_btn is kept because its text is toggled later)
row4 = Frame(master)
row4.pack(fill=X, padx=10, pady=(2, 8))
play_btn = Button(row4, text="Play", command=toggle_play)
play_btn.pack(side=LEFT, padx=(0, 6))
Button(row4, text="Reset", command=reset_play).pack(side=LEFT, padx=6)

# Keyboard shortcuts bound on the control window: space = play/pause, r = reset
master.bind("<space>", lambda e: toggle_play())
master.bind("r",       lambda e: reset_play())

# Pressing Return in any entry recomputes and replots
entry_acc.bind("<Return>", lambda e: replot())
entry_t0.bind("<Return>", lambda e: replot())
entry_t1.bind("<Return>", lambda e: replot())
entry_dt.bind("<Return>", lambda e: replot())
entry_x0.bind("<Return>", lambda e: replot())
entry_v0.bind("<Return>", lambda e: replot())

# Initial draw: compute and plot once at startup from the default entry values
replot()

# -------------  Main loops  -------------
def pump_pygame():
    """
    One tick of the combined event loop, driven by Tk's timer.

    Drains the pygame event queue (a window close request exits the program),
    advances the playback cursor while playing, and re-schedules itself.
    """
    global play_idx
    if any(event.type == pygame.QUIT for event in pygame.event.get()):
        Exit()

    # advance playback only when playing
    if is_playing and last_T:
        final_idx = len(last_T) - 1
        stride = 1
        if final_idx >= 1:
            sample_dt = last_T[1] - last_T[0]
            # number of samples covering ~0.03 s of simulated time per frame
            stride = max(1, int(0.03 / max(sample_dt, 1e-9)))
        play_idx = min(play_idx + stride, final_idx)
        redraw_from_last(cursor_idx=play_idx)
        # stop automatically at the end
        if play_idx >= final_idx:
            toggle_play()

    master.after(16, pump_pygame)  # ~60 Hz

# Start the pygame polling chain (it re-arms itself through master.after),
# then hand control to the Tk event loop.
pump_pygame()
master.mainloop()
