#!/usr/bin/env python3
import argparse, os, sys
import numpy as np
import matplotlib as mpl
mpl.use("Agg")
import matplotlib.pyplot as plt
from datetime import datetime

# ---------- Metadata per volume ----------
# Per-volume metadata, keyed by the short name accepted as the "vol" argument.
# Each entry holds:
#   filename      -- volume file name, looked up inside --volume-path
#   rvs           -- sample counts along (inline, crossline, slice); used as
#                    the memmap shape
#   axis_start    -- starting coordinate of each axis (plot-extent origin)
#   axis_interval -- sample spacing along each axis
VOLS = dict(
    pb3dv1=dict(
        filename="PLANUM_BOREUM_3D_V1_TIME.DAT",
        rvs=(5401, 5401, 1334),
        # NOTE(review): the third-axis start is 100.0 here but 0.0 for every
        # other volume -- confirm against the product label.
        axis_start=(-1282750.0, -1282750.0, 100.0),
        axis_interval=(475.0, 475.0, 0.0375),
    ),
    pb3dv2=dict(
        filename="PLANUM_BOREUM_3D_V2_TIME.DAT",
        rvs=(5401, 5401, 1335),
        axis_start=(-1282987.5, -1282987.5, 0.0),
        axis_interval=(475.0, 475.0, 0.0375),
    ),
    pa3dv1=dict(
        filename="PLANUM_AUSTRALE_3D_V1_TIME.DAT",
        rvs=(5475, 5475, 1334),
        axis_start=(-1300075.0, -1300075.0, 0.0),
        axis_interval=(475.0, 475.0, 0.0375),
    ),
    ed3dv2=dict(
        filename="EAST_DEUTERONILUS_MENSAE_3D_V2_TIME.DAT",
        rvs=(1280, 1420, 3600),
        axis_start=(-303762.5, -337012.5, 0.0),
        axis_interval=(475.0, 475.0, 0.0375),
    ),
)

def parseargs():
    """Parse the command line and validate it against the volume metadata.

    Returns:
        A 9-tuple ``(n, vol, dim, obase, radar_volume_path, rvs, axis_start,
        axis_interval, diag)``:

        * ``n`` -- the requested 0-based plane index, checked against the
          number of planes along ``dim``;
        * ``vol`` / ``dim`` -- the volume key and plane dimension as given;
        * ``obase`` -- output basename (``--outbase`` or the volume file name
          without its extension);
        * ``radar_volume_path`` -- path of an existing volume file;
        * ``rvs``, ``axis_start``, ``axis_interval`` -- taken from
          ``VOLS[vol]``;
        * ``diag`` -- whether ``--diag`` was given.

    Exits with status 1, after printing an ``[ERROR]`` line to stderr, when
    the volume file cannot be found or ``n`` is out of range. Malformed
    arguments are reported and rejected by argparse itself.
    """
    parser = argparse.ArgumentParser(
        description="SHARAD_3D_PlaneSlicer 2.2 (fast memmap edition)"
    )
    parser.add_argument("n", type=int, help="Inline, crossline, or slice index (0-based)")
    parser.add_argument("vol", choices=VOLS.keys(), help="Volume name")
    parser.add_argument("dim", choices=["inline", "crossline", "slice"], help="Plane dimension")
    parser.add_argument(
        "-rvp", "--volume-path", default="/data/k/SHARAD/3D/PDS/mrosh_3001/data/",
        help="Directory containing the radar volume (defaults to the known path)",
    )
    parser.add_argument(
        "-o", "--outbase", default=None,
        help="Basename for output (defaults to derived from volume file)"
    )
    parser.add_argument(
        "-d", "--diag", action="store_true",
        help="Verbose timings/diagnostics"
    )
    args = parser.parse_args()

    meta = VOLS[args.vol]
    rvs = meta["rvs"]
    axis_start = meta["axis_start"]
    axis_interval = meta["axis_interval"]

    # Resolve the volume file: try the name exactly as listed in VOLS, then
    # its lower-case variant. The check must be on the filesystem, not on
    # the (always non-empty) path string.
    candidates = [
        os.path.join(args.volume_path, meta["filename"]),
        os.path.join(args.volume_path, meta["filename"].lower()),
    ]
    radar_volume_path = next((p for p in candidates if os.path.isfile(p)), None)
    if radar_volume_path is None:
        print(
            "[ERROR] Volume not found: {}".format(" or ".join(candidates)),
            file=sys.stderr,
        )
        sys.exit(1)

    # Bounds check (0-based). The volume is indexed [inline, crossline,
    # slice], so the plane count for each dimension is the matching rvs entry.
    n = args.n
    if n < 0:
        print("[ERROR] Negative indices make no sense.", file=sys.stderr)
        sys.exit(1)
    count = rvs[("inline", "crossline", "slice").index(args.dim)]
    if n >= count:
        print(
            "[ERROR] {} index n={} out of range [0, {}]".format(args.dim, n, count - 1),
            file=sys.stderr,
        )
        sys.exit(1)

    # Output basename
    if args.outbase is None:
        obase = os.path.splitext(os.path.basename(radar_volume_path))[0]
    else:
        obase = args.outbase

    return n, args.vol, args.dim, obase, radar_volume_path, rvs, axis_start, axis_interval, args.diag


def compute_extents(dim, rvs, axis_start, axis_interval):
    """Return the imshow extent ``(L, R, B, T)`` and a figure size for a plane.

    Axis ``i`` of the volume (0 = inline, 1 = crossline, 2 = slice) has
    ``rvs[i]`` samples starting at ``axis_start[i]`` and spaced by
    ``axis_interval[i]``. Its span runs from ``axis_start[i]`` to
    ``axis_start[i] + rvs[i] * axis_interval[i]``, in the same units as
    ``axis_start``/``axis_interval``.

    The horizontal axis always follows the plane's columns:

    * ``"inline"``    plane ``[slice, crossline]``:  x = crossline, y = slice
    * ``"crossline"`` plane ``[slice, inline]``:     x = inline,    y = slice
    * ``"slice"``     plane ``[inline, crossline]``: x = crossline, y = inline

    For inline/crossline, ``T`` is the start of the slice axis and ``B`` its
    end. With imshow's default ``origin="upper"``, the first sample is
    therefore drawn at the top.

    Args:
        dim: ``"inline"``, ``"crossline"`` or ``"slice"``. Any other value is
            treated as ``"slice"``.
        rvs: Sample counts per volume axis.
        axis_start: Starting coordinate per volume axis.
        axis_interval: Sample spacing per volume axis.

    Returns:
        ``((L, R, B, T), figsize)``. ``figsize`` is ``(12, 6)`` for
        inline/crossline and ``(12, 12)`` for slice.
    """
    def _span(axis):
        """Return the (first, end) coordinates along volume axis ``axis``."""
        first = axis_start[axis]
        return first, first + rvs[axis] * axis_interval[axis]

    if dim in ("inline", "crossline"):
        # Columns run along the *other* spatial axis; rows run down in time.
        L, R = _span(1 if dim == "inline" else 0)
        T, B = _span(2)
        figsize = (12, 6)
    else:  # slice
        # mm[:, :, n] has rows = inlines (axis 0) and columns = crosslines
        # (axis 1), so x must span axis 1 and y must span axis 0.
        L, R = _span(1)
        B, T = _span(0)
        figsize = (12, 12)
    return (L, R, B, T), figsize


def main():
    """Render one plane of a SHARAD 3-D radar volume to ``<obase>_<dim>_<n>.png``.

    Reads the command line via ``parseargs``. It then memory-maps the volume
    as little-endian float32 with shape ``rvs`` and copies out the requested
    inline / crossline / slice plane. The plane is converted to dB and saved
    as a grayscale image at 500 dpi.
    """
    dtype = np.dtype("<f4")  # 4-byte little-endian float
    n, vol, dim, obase, radar_volume_path, rvs, axis_start, axis_interval, diag = parseargs()

    # The volume is laid out as [inline, crossline, slice] in C order
    # (inline-major contiguous blocks).
    shape = (rvs[0], rvs[1], rvs[2])

    # Memory-map the file: no Python loops or tiny reads; the OS pages in
    # only what the selected plane touches.
    mm = np.memmap(radar_volume_path, mode="r", dtype=dtype, shape=shape, order="C")

    # Indexing a memmap only builds a lazy view. The disk reads happen when
    # the view is copied into memory, so the copy sits inside the timed region.
    tic = datetime.now()
    if dim == "inline":
        plane = mm[n, :, :].T  # [slice, crossline]
    elif dim == "crossline":
        plane = mm[:, n, :].T  # [slice, inline]
    else:  # "slice"
        plane = mm[:, :, n]    # [inline, crossline] (no transpose here)
    data = np.array(plane, dtype=np.float32, order="C")  # owned, C-contiguous copy
    toc = datetime.now()
    if diag:
        print("[DIAG] slice extraction took {:.6f} s".format((toc - tic).total_seconds()))

    # Shift so the minimum is zero (only when negative values are present).
    # `data` is our own copy, so the in-place update is safe.
    mn = data.min()
    if mn < 0:
        data -= mn

    # dB scaling with a floor to avoid log10(0) = -inf. After the shift every
    # value is >= 0, so no abs() is needed.
    # NOTE(review): the fixed [-30, 0] dB colour range below assumes
    # amplitudes of order <= 1; confirm against the data.
    eps = 1e-30
    db = 20.0 * np.log10(np.maximum(data, eps))

    # Plot. Spatial coordinates are converted from metres to km. The time
    # axis of inline/crossline planes keeps its native units.
    (L, R, B, T), figsize = compute_extents(dim, rvs, axis_start, axis_interval)
    if dim == "slice":
        extents = [L / 1000.0, R / 1000.0, B / 1000.0, T / 1000.0]
    else:
        extents = [L / 1000.0, R / 1000.0, B, T]

    ifile = "{}_{}_{}.png".format(obase, dim, n)
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    im = ax.imshow(db, cmap="gray", extent=extents, aspect="auto", vmin=-30, vmax=0)
    ax.set_title(ifile, fontsize=16)
    ax.set_xlabel("Distance relative to Center of Volume\n(km)")
    if dim == "slice":
        ax.set_ylabel("Distance relative to Center of Volume\n(km)")
        ax.invert_yaxis()
    else:
        ax.set_ylabel("Delay Time from Processing Datum\n(microseconds)")
    fig.colorbar(im, orientation="horizontal")
    fig.savefig(ifile, dpi=500)
    plt.close(fig)  # release the (large, 500 dpi) figure's memory
    if diag:
        print("[DIAG] wrote {}".format(ifile))

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

