import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# ==========================================
# 3D Ideal Gas (hard-sphere) in a box
# - elastic wall + particle collisions
# - plots the speed and energy distributions as they evolve over time
# ==========================================

# -----------------------
# Parameters
# -----------------------
# System
N = 1000                        # number of particles (alternative value noted by the author: 120)
L = 1.0                         # edge length of the cubic box [0, L]^3
radius = 1 / (20 * N**(1/3))    # sphere radius, shrinks as N^(-1/3) (alternative noted: 0.015)
mass = 1.0                      # every particle has the same mass

# Time stepping
dt = 0.0015                     # length of one physics step
steps_per_frame = 1             # physics steps per animation frame (alternative noted: 6)
step_max = 100                  # the animation stops once this many steps have run

# Standard deviation of each initial velocity component
# (sets the initial "temperature-ish" scale)
v0 = 1.0

# Histograms: the speed and the energy panel share one bin count
num_bins = 60
v_max_plot = 6.0                # upper x-limit of the speed histogram
E_max_plot = 6.0                # upper x-limit of the energy histogram

# Collision handling: cooldown value stored for a pair right after it
# collides, to reduce repeated numerical "double hits"
pair_cooldown_steps = 2

# -----------------------
# Initialization
# -----------------------
rng = np.random.default_rng(0)  # fixed seed for reproducibility; pass None for fresh randomness

# Place particles by rejection sampling so that no two spheres overlap.
# Each candidate is drawn so the whole sphere lies inside the box, then
# compared against every particle placed so far.
max_placement_attempts = 100000   # per particle; guards against an endless loop
min_separation2 = (2 * radius) ** 2
pos = np.zeros((N, 3), dtype=float)
for i in range(N):
    for _attempt in range(max_placement_attempts):
        trial = rng.uniform(radius, L - radius, size=3)
        # pos[:i] is empty for i == 0, so np.all(...) is True and the first
        # particle is accepted without a special case.
        d2 = np.sum((pos[:i] - trial) ** 2, axis=1)
        if np.all(d2 > min_separation2):
            pos[i] = trial
            break
    else:
        # Without this guard an over-dense configuration would spin forever.
        raise RuntimeError(
            "could not place particle {} without overlap after {} attempts; "
            "reduce N or radius".format(i, max_placement_attempts)
        )

# Velocities ~ normal in each component (Maxwellian components)
vel = rng.normal(0.0, v0, size=(N, 3))

# Remove any net momentum so the centre of mass stays at rest
vel -= vel.mean(axis=0, keepdims=True)

# Track recent collisions: dict[(i, j)] = cooldown_remaining
recent = {}

# Running totals of collision events
wall_collision_count = 0
particle_collision_count = 0

# -----------------------
# Physics helpers
# -----------------------
def wall_collisions(pos, vel, box_length=None, sphere_radius=None):
    """Reflect particles elastically off the walls of the cubic box.

    ``pos`` and ``vel`` are modified in place.  A particle whose centre is
    closer than one radius to a wall is moved back to the contact position.
    Its velocity component normal to that wall is reversed only when it is
    still heading out of the box: a particle that was pushed past the wall
    (for example by the pair-overlap correction) but is already moving
    inward keeps its velocity, so it is neither sent back into the wall nor
    counted as a collision.

    Parameters
    ----------
    pos, vel : ndarray, shape (n, 3)
        Particle positions and velocities.
    box_length : float, optional
        Edge length of the box ``[0, box_length]^3``.  Defaults to the
        module-level ``L``.
    sphere_radius : float, optional
        Particle radius.  Defaults to the module-level ``radius``.

    Returns
    -------
    int
        Number of velocity reflections performed in this call.
    """
    if box_length is None:
        box_length = L
    if sphere_radius is None:
        sphere_radius = radius

    lower = sphere_radius
    upper = box_length - sphere_radius

    events = 0
    for axis in range(3):
        below = pos[:, axis] < lower
        above = pos[:, axis] > upper

        # Only outward-moving particles actually bounce.
        bounce_low = below & (vel[:, axis] < 0)
        bounce_high = above & (vel[:, axis] > 0)

        events += int(np.count_nonzero(bounce_low))
        events += int(np.count_nonzero(bounce_high))

        vel[bounce_low, axis] *= -1
        vel[bounce_high, axis] *= -1

        # Every particle beyond a wall is put back at the contact position.
        pos[below, axis] = lower
        pos[above, axis] = upper
    return events

def particle_collisions(pos, vel, recent, sphere_radius=None, cooldown_steps=None):
    """Resolve hard-sphere collisions between equal-mass particles in 3D.

    Pairs ``(i, j)`` with ``i < j`` are handled in lexicographic order, and
    each pair is tested against the positions as already corrected by
    earlier collisions in the same call.  The distance test for a whole row
    ``i`` is done with one NumPy expression instead of a Python loop over
    ``j``; the row is re-screened after every collision because that
    collision moves particle ``i``.  The work is still O(n^2) per call.

    A pair collides when the spheres overlap *and* are approaching.  The
    relative velocity is then reflected along the line of centres, and both
    spheres are moved apart symmetrically until they just touch.

    ``pos``, ``vel`` and ``recent`` are modified in place.

    Parameters
    ----------
    pos, vel : ndarray, shape (n, 3)
        Particle positions and velocities.
    recent : dict
        Maps ``(i, j)`` to a remaining cooldown.  Every entry is decremented
        at the start of the call and dropped once it reaches zero; pairs
        still present are skipped.  A pair stored with cooldown ``c`` is
        therefore skipped during the next ``c - 1`` calls.
    sphere_radius : float, optional
        Particle radius.  Defaults to the module-level ``radius``.
    cooldown_steps : int, optional
        Cooldown stored for a pair that has just collided.  Defaults to the
        module-level ``pair_cooldown_steps``.

    Returns
    -------
    int
        Number of pair collisions resolved in this call.
    """
    if sphere_radius is None:
        sphere_radius = radius
    if cooldown_steps is None:
        cooldown_steps = pair_cooldown_steps

    # Decay cooldowns (iterate over a snapshot so entries can be deleted).
    for key in list(recent):
        recent[key] -= 1
        if recent[key] <= 0:
            del recent[key]

    n = pos.shape[0]
    min_dist = 2 * sphere_radius
    min_dist2 = min_dist * min_dist

    events = 0
    for i in range(n - 1):
        start = i + 1
        while start < n:
            # Squared distances from particle i to particles start..n-1.
            delta = pos[i] - pos[start:]
            overlapping = np.flatnonzero(
                np.einsum("ij,ij->i", delta, delta) < min_dist2
            )

            for offset in overlapping:
                j = start + int(offset)
                key = (i, j)
                if key in recent:
                    continue

                dr = delta[offset]
                approach = np.dot(vel[i] - vel[j], dr)
                if not approach < 0:
                    continue  # overlapping but not approaching

                # Equal-mass elastic collision: exchange the velocity
                # components along the line of centres.
                dist2 = np.dot(dr, dr)
                impulse = (approach / dist2) * dr
                vel[i] -= impulse
                vel[j] += impulse

                # Positional correction: push both spheres apart to contact.
                dist = np.sqrt(dist2)
                if dist > 0:
                    corr = ((min_dist - dist) / 2.0) * (dr / dist)
                    pos[i] += corr
                    pos[j] -= corr

                recent[key] = cooldown_steps
                events += 1

                # pos[i] changed, so the rest of this row must be re-screened.
                start = j + 1
                break
            else:
                break  # no collision left in this row
    return events

def step(pos, vel, recent):
    """Advance the system by one time step ``dt``.

    Particles first drift freely (``pos`` is updated in place), then wall
    collisions are resolved, then particle-pair collisions.

    Returns the 4-tuple ``(pos, vel, wall_events, particle_events)``, where
    the last two entries are the event counts reported by
    ``wall_collisions`` and ``particle_collisions`` for this step.
    """
    pos += vel * dt
    n_wall = wall_collisions(pos, vel)
    n_pair = particle_collisions(pos, vel, recent)
    return pos, vel, n_wall, n_pair


# -----------------------
# Theory: 3D Maxwell speed distribution
# -----------------------
# f(v) = 4*pi * (m/(2*pi*kT))^(3/2) * v^2 * exp(-m v^2/(2 kT))
# We estimate kT from kinetic energy:
#   <E> = (3/2) kT,   and <E> = (1/2) m <v^2>
#   => kT = (m <v^2>)/3
def maxwell_speed_pdf_3d(v, kT_over_m):
    """Evaluate the 3D Maxwell speed probability density.

    With ``a = kT/m`` (passed as ``kT_over_m``) the density is

        f(v) = 4*pi * (1/(2*pi*a))^(3/2) * v^2 * exp(-v^2 / (2*a))

    ``v`` may be a scalar or array-like; it is converted with
    ``np.asarray`` and the result has the same shape.
    """
    speed = np.asarray(v)
    speed_sq = speed**2
    normalisation = 4.0 * np.pi * (1.0 / (2.0 * np.pi * kT_over_m)) ** 1.5
    return normalisation * speed_sq * np.exp(-speed_sq / (2.0 * kT_over_m))

# -----------------------
# Theory: 3D energy distribution
# -----------------------
# For 3D translational DOF:
#   E = (1/2) m v^2
#   f(E) = (2/sqrt(pi)) * (1/(kT)^(3/2)) * sqrt(E) * exp(-E/kT),  E>=0
# and <E> = (3/2) kT  =>  kT = (2/3) <E>
def mb_energy_pdf_3d(E, kT):
    """Evaluate the 3D Maxwell-Boltzmann kinetic-energy probability density.

        f(E) = (2/sqrt(pi)) * kT^(-3/2) * sqrt(E) * exp(-E/kT)

    ``E`` may be a scalar or array-like; it is converted with
    ``np.asarray`` and the result has the same shape.
    """
    energy = np.asarray(E)
    normalisation = (2.0 / np.sqrt(np.pi)) * (1.0 / (kT ** 1.5))
    return normalisation * np.sqrt(energy) * np.exp(-energy / kT)


# -----------------------
# Plot setup (speed histogram + energy histogram, side by side)
# -----------------------
fig = plt.figure(figsize=(10, 4))
gs = fig.add_gridspec(1, 2, width_ratios=[1, 1])

ax_shist = fig.add_subplot(gs[0, 0])   # left panel: speed distribution
ax_ehist = fig.add_subplot(gs[0, 1])   # right panel: energy distribution

# On-plot status text (steps + collisions + thermodynamic estimates).
# NOTE: Axes.cla() removes every artist from an axes, this text included,
# so the text is only visible while ax_shist has not been cleared.
status_text = ax_shist.text(0.02, 0.98, "", transform=ax_shist.transAxes, va="top")

# Initial labelling and fixed axis limits for both panels
ax_shist.set(
    title=f"Speed Distribution n={N}",
    xlabel="Speed |v|",
    ylabel="Proportion of Molecules",
    xlim=(0, v_max_plot),
    ylim=(0, 0.8),
)
ax_ehist.set(
    title=f"Energy Distribution n={N}",
    xlabel=r"Energy  $E=\frac{1}{2} m v^2$",
    ylabel="Proportion of Molecules",
    xlim=(0, E_max_plot),
    ylim=(0, 0.8),
)

# Per-particle running sums over animation frames
speeds_sum = np.zeros(N)
E_sum = np.zeros(N)

# Progress counters: physics steps taken and frames drawn
step_counter = 0
frame_counter = 0

def _redraw_histogram(ax, samples, x_max, title, xlabel, curve_x, curve_y):
    """Clear ``ax`` and draw a density histogram of ``samples`` plus a theory curve."""
    ax.cla()
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Proportion of Molecules")
    ax.hist(samples, bins=num_bins, range=(0, x_max), density=True, alpha=1)
    ax.plot(curve_x, curve_y, lw=2)
    # Fixed limits keep the panels from rescaling between frames.
    ax.set_xlim(0, x_max)
    ax.set_ylim(0, 0.8)


def update(frame):
    """Animation callback: advance the physics and redraw both histograms.

    Runs ``steps_per_frame`` physics steps, accumulates the collision
    counters and the per-particle running sums (``speeds_sum``, ``E_sum``),
    then redraws the speed and energy histograms of the *current* state with
    the Maxwell-Boltzmann curves for the temperature estimated from the
    current kinetic energy.  Once ``step_counter`` has reached ``step_max``
    the animation timer is stopped and nothing is drawn.

    ``frame`` is the frame value supplied by ``FuncAnimation``; it is unused.
    """
    global pos, vel, recent, step_counter, frame_counter
    global wall_collision_count, particle_collision_count
    global speeds_sum, E_sum

    if step_counter >= step_max:
        ani.event_source.stop()
        return

    # advance physics
    for _ in range(steps_per_frame):
        pos, vel, w_ev, p_ev = step(pos, vel, recent)
        step_counter += 1
        wall_collision_count += w_ev
        particle_collision_count += p_ev
    frame_counter += 1

    # instantaneous speeds and kinetic energies
    speeds2 = np.sum(vel * vel, axis=1)
    speeds = np.sqrt(speeds2)
    E = 0.5 * mass * speeds2

    # Running sums over frames; the time averages are
    # speeds_sum / frame_counter and E_sum / frame_counter.
    speeds_sum += speeds
    E_sum += E

    # Temperature estimates: kT/m = <v^2>/3 and kT = (2/3) <E>
    v2_mean = np.mean(speeds2)
    kT_over_m = v2_mean / 3.0
    kT_est = (2.0 / 3.0) * np.mean(E)

    # speed panel
    v_grid = np.linspace(0, v_max_plot, 400)
    _redraw_histogram(
        ax_shist,
        speeds,
        v_max_plot,
        f"Speed Distribution n={N}",
        "Speed |v|",
        v_grid,
        maxwell_speed_pdf_3d(v_grid, kT_over_m),
    )

    # energy panel
    E_grid = np.linspace(0, E_max_plot, 300)
    _redraw_histogram(
        ax_ehist,
        E,
        E_max_plot,
        f"Energy Distribution n={N}",
        r"Energy  $E=\frac{1}{2} m v^2$",
        E_grid,
        mb_energy_pdf_3d(E_grid, kT_est),
    )

    # Status text.  cla() removes every artist from the axes, so the text is
    # created anew after each redraw; a text object made once at setup time
    # would vanish on the first frame.  It sits in the top-right corner,
    # away from the peak of the distribution.
    ax_shist.text(
        0.98,
        0.98,
        "steps: {}\n"
        "wall colls: {}\n"
        "pair colls: {}\n"
        "<v^2>: {:.3f}\n"
        "kT/m: {:.3f}".format(
            step_counter,
            wall_collision_count,
            particle_collision_count,
            v2_mean,
            kT_over_m,
        ),
        transform=ax_shist.transAxes,
        va="top",
        ha="right",
    )
    print(step_counter, particle_collision_count)
    return

# Keep the animation in a module-level name: update() stops the timer
# through `ani` once step_max physics steps have been taken.
ani = FuncAnimation(
    fig,
    update,
    frames=1000,
    interval=20,
    blit=False,
)
plt.tight_layout()
plt.show()
