#!/usr/bin/env python
#

"""
gncplot: 2D visualization of netCDF data

E.g., gncplot --variable=air --lon=0 --time=0 --lev=1000,0 --label --output=cs.png air.mon.ltm.nc
"""

from __future__ import absolute_import, print_function

import os
import sys
import termios
import tty

try:
    from collections import OrderedDict
except ImportError:
    from ordereddict import OrderedDict

import netCDF4
import gmatplot as gm
gm.setup()
import matplotlib.pyplot as plt

import pylab

try:
    import gterm
except ImportError:
    import graphterm.bin.gterm as gterm

def locate_value(coords, value):
    """Return the index of the coordinate nearest to value.

    coords is treated as ascending when its last element exceeds its first,
    and as descending otherwise. Values outside the coordinate range map to
    the nearest end index. A value exactly midway between two neighbouring
    coordinates maps to the larger of the two coordinates.
    """
    last = len(coords) - 1
    ascending = coords[-1] > coords[0]

    # Start from the end point that applies when value lies outside the range
    if ascending:
        nearest = 0 if value <= coords[0] else last
    else:
        nearest = 0 if value >= coords[0] else last

    # Scan every pair of neighbours for the interval bracketing value
    for idx in range(last):
        here, there = coords[idx], coords[idx + 1]
        if ascending:
            if here < value <= there:
                nearest = idx if (value - here) < (there - value) else idx + 1
        elif there <= value < here:
            nearest = idx + 1 if (value - there) < (here - value) else idx

    return nearest

# Standard coordinate abbreviations mapped to their descriptive names
# (the "lon" entry previously carried the misspelling "longiitude")
COORD_NAMES = OrderedDict([("time", "time"), ("lev", "level"), ("lat", "latitude"), ("lon", "longitude")])

# Usage template passed to the form parser
usage = "usage: %prog [options] <netCDF file>"

# At least the netCDF file name must follow the program name
if not sys.argv[1:]:
    sys.exit("Usage: %s [options] netCDF_file" % sys.argv[0])

# The file to visualize is always the last command line argument
filepath = sys.argv[-1]

if not os.path.isfile(filepath):
    sys.exit("Unable to open file %s" % filepath)

ncfile = netCDF4.Dataset(filepath)

# Names of all dimensions and all variables found in the file
dimension_names = list(ncfile.dimensions.keys())
varnames = list(ncfile.variables.keys())

# Lookup tables relating standard coordinate names ("lon", "lat", "lev",
# "time") to the coordinate variables of this file
coord_varnames = {}      # standard name -> variable name
coord_stdnames = {}      # variable name -> standard name
coord_vals = {}          # standard name -> coordinate values

# Variable names and units recognized as a vertical level coordinate
_LEVEL_VARNAMES = ("altitude", "height", "sigma", "sigma_level", "hybrid_level",
                   "eta", "eta_level", "layer", "lev", "level", "levels", "z")
_PRESSURE_UNITS = ("mb", "mbar", "millibar", "millibars", "bar", "hpa",
                   "hectopascal", "hectopascals")
# Unit prefixes recognized as a time coordinate
_TIME_UNIT_PREFIXES = ("second", "minute", "hour", "day", "year")

# Split the variables into plottable variables and dimension (coordinate)
# variables; variables whose name ends in "bounds" are ignored
dim_varnames = {}        # dimension variable name -> long_name attribute
plot_varnames = []
for varname in varnames:
    if varname.endswith("bounds"):
        continue
    if varname not in dimension_names:
        plot_varnames.append(varname)
        continue

    ncvar = ncfile.variables[varname]
    dim_varnames[varname] = getattr(ncvar, "long_name", "")
    units = getattr(ncvar, "units", "").lower()
    lname = varname.lower()

    # Reset for every variable: otherwise an unrecognized coordinate would
    # raise NameError (first variable) or inherit the previous classification
    coord_name = ""
    if lname in ("lon", "longitude") or units.startswith("degrees_e"):
        coord_name = "lon"
    elif lname in ("lat", "latitude") or units.startswith("degrees_n"):
        coord_name = "lat"
    elif lname in _LEVEL_VARNAMES or lname.startswith("press") or units in _PRESSURE_UNITS:
        coord_name = "lev"
    elif units.startswith(_TIME_UNIT_PREFIXES):
        coord_name = "time"

    if coord_name:
        # These variables are not used yet!
        coord_varnames[coord_name] = varname
        coord_stdnames[varname] = coord_name
        coord_vals[coord_name] = ncvar[:]

plot_varnames.sort()

dim_varname_list = sorted(dim_varnames)

# Describe the command line options / input form
form_parser = gterm.FormParser(usage=usage, title="Visualize netCDF file: ", command="gncplot -f")
form_parser.add_argument(default_value=filepath, label="", help="NetCDF file")

# Choice lists start with an empty entry meaning "nothing selected"
variable_choices = [""] + plot_varnames
form_parser.add_option("variable", variable_choices, help="Variable name")
if dim_varname_list:
    scroll_choices = [""] + dim_varname_list
    form_parser.add_option("scroll", scroll_choices, help="Scroll coordinate")

# One option per coordinate variable, defaulting to its full "min,max" range
for coord_varname in dim_varname_list:
    coord_values = ncfile.variables[coord_varname][:]
    full_range = "%g,%g" % (min(coord_values), max(coord_values))
    form_parser.add_option(coord_varname.lower(), full_range,
                           help=dim_varnames[coord_varname]+" value[,value]")

# Remaining options, registered in display order
for opt_name, opt_default, opt_extra in (
        ("title", "", {"help": "Plot title"}),
        ("label", False, {"help": "Label contours"}),
        ("output", "", {"short": "o", "help": "Write PNG image to file"}),
        ("verbose", False, {"help": "Verbose diagnostic messages"}),
        ("form", False, {"help": "Force form display", "raw": True}),
        ("fullscreen", False, {"short": "f", "help": "Fullscreen display", "raw": True}),
):
    form_parser.add_option(opt_name, opt_default, **opt_extra)

(options, args) = form_parser.parse_args()

# Accumulated error messages; a non-empty value triggers the error/form path
errmsg = ""

plotvar = options.variable
if plotvar:
    if plotvar not in ncfile.variables:
        sys.exit("Error: Variable %s not found in file" % plotvar)
    plot_obj = ncfile.variables[plotvar]

    # The "scroll" option is only registered when the file has dimension
    # variables, so it may be absent from the parsed options
    scroll_var = getattr(options, "scroll", "") or ""
    scroll_vals = ncfile.variables[scroll_var][:] if scroll_var else []

    title_prefix = plotvar
    slice_obj = []      # one index or slice per dimension of the plot variable
    kscroll = -1        # position of the scroll dimension in slice_obj, if fixed
    plot_dims = []      # names of the dimensions that remain to be plotted
    plot_slice = []     # slice applied to each plotted dimension (None => all)
    plot_lim = []       # requested [value, value] limits per plotted dimension
    for kdim, dname in enumerate(plot_obj.dimensions):
        cval_str = getattr(options, dname.lower(), "")
        if not cval_str:
            # No selection given: plot the dimension over its full extent
            slice_obj.append(slice(None))
            plot_dims.append(dname)
            plot_slice.append(None)
            plot_lim.append(None)
            continue

        try:
            comps = cval_str.replace(" ","").split(",")
            if len(comps) > 2:
                # Append (rather than overwrite) so earlier errors are kept
                errmsg += "Too many coordinate values for %s\n" % dname
                continue
            coords = ncfile.variables[dname][:]
            cvals = [float(comp) for comp in comps]
            jvals = [locate_value(coords, cval) for cval in cvals]
        except Exception as excp:
            # Report the raw option text; the comprehension variables above
            # are not visible here on Python 3
            errmsg += "Invalid %s value %s: %s\n" % (dname, cval_str, excp)
            continue

        if len(jvals) == 1 or jvals[0] == jvals[1]:
            # Single coordinate value: slice this dimension out of the plot
            slice_obj.append(jvals[0])
            if dname == scroll_var:
                kscroll = kdim
            if dname != scroll_var or not options.fullscreen:
                title_prefix += " %s=%s" % (dname, coords[jvals[0]])
        else:
            # Coordinate range: keep the dimension, restricted to the range
            plot_lim.append(cvals)
            slc = slice(min(jvals), max(jvals)+1)
            slice_obj.append(slc)
            plot_slice.append(slc)
            plot_dims.append(dname)

    nplotdim = len(plot_dims)

    if not errmsg and nplotdim > 2:
        errmsg += "Too many dimensions to plot; slice out %d dimensions" % (nplotdim-2)

    if options.verbose:
        print("Slice=%s plot_dims=%s" % (slice_obj, plot_dims), file=sys.stderr)

# Without a usable plot request, fall back to the interactive form (when
# gterm.Cookie is set) or exit with a message
if errmsg or not plotvar or options.form:
    if not gterm.Cookie:
        sys.exit(errmsg if errmsg else "Unable to display interactive form; specify sufficient arguments to plot")

    if errmsg:
        error_headers = {"x_gterm_response": "error_message",
                         "x_gterm_parameters": {}
                         }
        gterm.wrap_write(errmsg, headers=error_headers)

    # Pre-fill the form with the parsed values only when it was requested explicitly
    form_prefill = (options, args) if options.form else None
    gterm.write_form(form_parser.create_form(prefill=form_prefill, errmsg=errmsg), command="gncplot -f")
    sys.exit(1)

# Coordinate arrays and optional axis limits for the plot axes: the last
# plotted dimension supplies X and, when two are plotted, the first supplies Y
X_lim = Y_lim = None
if nplotdim:
    x_var = ncfile.variables[plot_dims[-1]]
    if not plot_lim[-1]:
        X_coord = x_var[:]
    else:
        X_lim = plot_lim[-1]
        X_coord = x_var[plot_slice[-1]]

    if nplotdim > 1:
        y_var = ncfile.variables[plot_dims[0]]
        if not plot_lim[0]:
            Y_coord = y_var[:]
        else:
            Y_lim = plot_lim[0]
            Y_coord = y_var[plot_slice[0]]

def show_plot(slice_obj, title, output="", clabel=False, fullscreen=False):
    """Display (or save) the selected slice of the plot variable.

    Reads the module-level plot_obj, nplotdim, X_coord, Y_coord, X_lim,
    Y_lim and options. With no plotted dimension the selected data is
    printed; with one a line plot is drawn; otherwise filled contours with
    a horizontal colorbar are drawn.

    slice_obj:  sequence with one index or slice per variable dimension
    title:      plot title (also printed before the data when nothing is plotted)
    output:     passed to gm.show as outfile
    clabel:     label the contours (also enabled by the --label option)
    fullscreen: passed to gm.show as both fullscreen and overwrite
    """
    # Index with a tuple so each entry always applies to its own dimension;
    # an all-integer list could be taken as a fancy index on the first one
    plot_data = plot_obj[tuple(slice_obj)]
    if not nplotdim:
        print(title, plot_data)
        return

    gm.figure()
    if nplotdim == 1:
        plt.plot(X_coord, plot_data)
        if X_lim:
            plt.xlim(*X_lim)
    else:
        plt.contourf(X_coord, Y_coord, plot_data, 10)
        plt.colorbar(orientation="horizontal")
        if clabel or options.label:
            cont = plt.contour(X_coord, Y_coord, plot_data, 10)
            plt.clabel(cont)
        if X_lim:
            plt.xlim(*X_lim)
        if Y_lim:
            plt.ylim(*Y_lim)

    # Show once, after the title is set, so 1-D plots are not displayed
    # twice and honour the output/fullscreen settings
    plt.title(title)
    gm.show(fullscreen=fullscreen, overwrite=fullscreen, outfile=output)

# One-shot display: explicit title if given, else the generated one
static_title = options.title or title_prefix

# Write the image to a file and finish
if options.output:
    show_plot(slice_obj, static_title, output=options.output, clabel=options.label)
    sys.exit(0)

# Scrolling needs both fullscreen mode and a scroll coordinate; otherwise
# display a single plot and finish
if not (options.fullscreen and scroll_var):
    show_plot(slice_obj, static_title, clabel=options.label, fullscreen=options.fullscreen)
    sys.exit(0)

# Scrolling steps the index of the scroll coordinate, so that coordinate
# must have been fixed at a single value; with kscroll == -1 the loop
# would instead overwrite the slice of the last dimension
if kscroll < 0:
    sys.exit("Error: scroll coordinate %s must be set to a single value" % scroll_var)

Stdin_fd = sys.stdin.fileno()
Saved_settings = termios.tcgetattr(Stdin_fd)

first = True                  # True until the first frame has been displayed
jsel = slice_obj[kscroll]     # currently displayed index along the scroll coordinate
try:
    print("Scroll mode: SPC/'f' => forward, BSP/'b' => back, 'q' or ESCAPE => quit\n'p' => pause, 'r' => resume", file=sys.stderr)

    # Raw tty input without echo
    tty.setraw(Stdin_fd)
    while True:
        ch = sys.stdin.read(1)
        # An empty read means end of input; treat it like quit rather than
        # looping forever
        if not ch or ch in ("\x03", "\x04", "\x1b", "q"): # EOF/^C/^D/ESC/q
            gterm.write_blank(display="fullpage", exit_page=True)
            sys.exit(0)

        if ch == "f" or ch == " ":
            # The first forward key displays the starting frame itself
            jnew = jsel if first else jsel + 1
        elif ch == "b" or ch == "\x08" or ch == "\x7f": # Backspace
            jnew = jsel - 1
        elif ch == "r":
            jnew = jsel
        elif ch == "p":
            gterm.write_blank(display="fullpage", exit_page=True)
            continue
        else:
            continue

        # Clamp to the valid index range; skip redraw when nothing changed
        jnew = max(0, min(jnew, len(scroll_vals)-1) )
        if jsel == jnew and ch != "r" and not first:
            continue

        first = False
        jsel = jnew
        # Work on a copy so the original selection list is left untouched
        tem_slice = list(slice_obj)
        tem_slice[kscroll] = jsel
        title = title_prefix + " #%d (%s=%s)" % (jsel+1, scroll_var, scroll_vals[jsel])
        show_plot(tem_slice, title, fullscreen=True)

except KeyboardInterrupt:
    sys.exit(1)
finally:
    # Always restore the terminal settings saved before entering raw mode
    termios.tcsetattr(Stdin_fd, termios.TCSADRAIN, Saved_settings)


    
