#!/usr/bin/env python
#

"""
ec2launch: Launch Amazon AWS EC2 instance
"""

from __future__ import absolute_import, print_function

import collections
import json
import os
import random
import re
import sys
import time

import boto
import ec2common

try:
    import gterm
except ImportError:
    import graphterm.bin.gterm as gterm

# Local SSH settings: the EC2 key pair named `key_name` is kept as
# <ssh_dir>/<key_name>.pem
ssh_dir = os.path.expanduser("~/.ssh")
key_name = "ec2key"

# Login account on the launched instance and its GraphTerm settings directory
gterm_user = "ubuntu"
gterm_dir = "~%s/.graphterm" % gterm_user

# Web console addresses
aws_url = "https://console.aws.amazon.com/ec2"
gdev_url = "https://console.developers.google.com"

# Usage line handed to the form parser
usage = "usage: %prog [-h ... ] <instance-tag|full-domain-name>"

# Parser used both for command-line parsing and for rendering the launch form
form_parser = gterm.FormParser(
    usage=usage,
    title="Create Amazon EC2 instance with hostname (e.g., tagname OR tagname.domain): ",
    command="ec2launch -f")
form_parser.add_argument(label="", help="Instance hostname")

# (option name, default value, help text) in registration order.
# A tuple default presumably lists the allowed choices -- confirm against
# gterm.FormParser.
_launch_options = (
    ("type", ("m3.medium", "m3.large", "c3.large", "r3.large", "r3.xlarge", "m1.small", "t1.micro"), "Instance type"),
    ("ami", "ami-2f8f9246", "Instance OS (default: Ubuntu 12.04LTS amd64, us-east-1, rel. 20140408)"),
    ("gmail_addr", "", "Full gmail address, user@gmail.com, for use in Google authentication"),
    ("auth_type", ("singleuser", "none", "name", "multiuser"), "Authentication type (singleuser/none/name/multiuser)"),
    ("https", False, "Use https for security"),
    ("pylab", False, "Install PyLab"),
    ("netcdf", False, "Install NetCDF"),
    ("R", False, "Install R"),
    ("allow_embed", False, "Allow cross-domain iframe embedding"),
    ("allow_share", False, "Allow sharing between users"),
    ("notebook_server", False, "Allow password-protected access to IPython Notebook server"),
    ("logging", False, "Enable logging"),
    ("oshell", False, "Enable OTrace shell"),
    ("screen", False, "Run server/host using GNU screen"),
    ("other_opts", "", "Other options for gtermserver (space-separated)"),
    ("install", "", "Tar file to install on startup (with optional '; command' suffix)"),
    ("copy_files", "", "List of files to copy (space separated, e.g., .twitter_auth)"),
    ("connect_to_gterm", "", "Full domain name of GraphTerm server to connect to (for hosts)"),
    ("dry_run", False, "Dry run"),
    ("verbose", False, "Verbose"),
)
for _opt_name, _opt_default, _opt_help in _launch_options:
    form_parser.add_option(_opt_name, _opt_default, help=_opt_help)
##form_parser.add_option("wildcard", False, help="Enable wildcard subdomains")

# Display-control options, registered with raw=True
form_parser.add_option("form", False, help="Force form display", raw=True)
form_parser.add_option("fullpage", False, short="f", help="Fullpage display", raw=True)
form_parser.add_option("text", False, short="t", help="Text only", raw=True)

(options, args) = form_parser.parse_args()

# Force text mode unless gterm.Cookie is set and stdout is a terminal
if not (gterm.Cookie and sys.stdout.isatty()):
    options.text = True

# Without a hostname argument (or when the form is explicitly requested),
# show the usage text or the form and exit with status 1
if options.form or not args:
    if options.text:
        print(form_parser.get_usage(), file=sys.stderr)
        sys.exit(1)
    prefill = (options, args) if options.form else None
    gterm.write_form(form_parser.create_form(prefill=prefill), command="ec2launch -f")
    sys.exit(1)

# Google OAuth client credentials; both stay blank unless a gmail address
# was supplied
gclient_id = gclient_secret = ""
if options.gmail_addr:
    print("", file=sys.stderr)

    # Offer to reuse credentials returned by gterm.read_oauth()
    oauth_creds = gterm.read_oauth()
    if oauth_creds and "google_oauth" in oauth_creds:
        prompt = "Use Google OAuth credentials from ~/.graphterm/%s? (y/n): " % gterm.APP_OAUTH_FILENAME
        ack = gterm.get_input(prompt)
        if ack.strip().lower().startswith("y"):
            saved_creds = oauth_creds["google_oauth"]
            gclient_id = saved_creds["key"]
            gclient_secret = saved_creds["secret"]

    # Otherwise (or if either value is blank) ask for them interactively
    if not (gclient_id and gclient_secret):
        print("To use Google Authentication for %s, please enter Project Credentials for web application (or blank for delayed setup)." % options.gmail_addr, file=sys.stderr)
        print("  You can create a new project and Client ID at "+gdev_url, file=sys.stderr)
        gclient_id = gterm.get_input("  Client ID: ")
        gclient_secret = gterm.get_input("  Client Secret: ")
    print("", file=sys.stderr)

# First positional argument is the instance hostname (tag or tag.domain)
instance_tag = args[0]

# Must start with a lowercase letter; the rest may only be lowercase
# letters, digits, dots and hyphens
if re.match(r"^[a-z][a-z0-9\.\-]*$", instance_tag) is None:
    raise Exception("Invalid characters in instance name: "+instance_tag)

# Split at the first dot: host label, then domain (empty when no dot)
instance_name, sep, instance_domain = instance_tag.partition(".")

# Private key file for the EC2 key pair
key_file = os.path.join(ssh_dir, "%s.pem" % key_name)

if not os.path.exists(ssh_dir):
    print("ec2launch: %s directory not found!" % ssh_dir, file=sys.stderr)
    sys.exit(1)

# Parse the "install" option: "<archive>[<sep> <command>]".
#   install_file     local path of the file to copy to the instance
#   install_cmd      optional command suffix ("pip" selects pip installation)
#   install_basename file name without directory
#   install_baseroot file name without archive extension(s)
#   unarch_cmd       shell command used to unpack the archive remotely
install_file, install_cmd, install_basename, install_baseroot, unarch_cmd = "", "", "", "", ""
copy_files = []
if options.install:
    # The option's help text documents a ';' separator whereas this script
    # historically split on ','. Accept whichever occurs first, and drop the
    # whitespace around it so that "file.tgz; pip" still selects pip mode.
    install_parts = re.split(r"\s*[;,]\s*", options.install.strip(), maxsplit=1)
    install_file = install_parts[0]
    install_cmd = install_parts[1] if len(install_parts) > 1 else ""
    if not os.path.isfile(install_file):
        raise Exception("Install file "+install_file+" not found")
    install_basename = os.path.basename(install_file)
    if install_cmd != "pip":
        # Pick the unarchive command from the file extension
        unarch_by_ext = {".tgz": "tar zxf", ".tar": "tar xf", ".zip": "unzip"}
        install_baseroot, ext = os.path.splitext(install_basename)
        if ext == ".gz":
            # Only .tar.gz is accepted among .gz files
            install_baseroot, ext = os.path.splitext(install_baseroot)
            if ext != ".tar":
                raise Exception("Invalid extension for install file: "+install_file)
            unarch_cmd = "tar zxf"
        elif ext in unarch_by_ext:
            unarch_cmd = unarch_by_ext[ext]
        else:
            raise Exception("Invalid extension for install file: "+install_file)
    copy_files.append(install_file)

# Additional files to copy to the instance (whitespace separated)
if options.copy_files:
    copy_files += options.copy_files.split()

auth_type = options.auth_type.strip()
allow_host_connect = auth_type in ("singleuser", "multiuser")

# One ingress rule; ports are kept as strings
GroupRule = collections.namedtuple(
    "GroupRule", "ip_protocol from_port to_port cidr_ip src_group_name")

host_group = "gtermhost"
server_group = "gtermserver"
nb_server_group = "gtermnbserver"


def _tcp_rule(from_port, to_port=None, src_group_name=None):
    """Return a TCP GroupRule for a port (or port range) with CIDR 0.0.0.0/0."""
    return GroupRule("tcp", from_port, to_port or from_port, "0.0.0.0/0", src_group_name)


# The 8899 rule names host_group as its source group unless
# allow_host_connect is set
_host_port_rule = _tcp_rule("8899", src_group_name=None if allow_host_connect else host_group)

# (group name, description, rules) for every security group managed here
SecurityGroups = [
    (host_group, "GraphTerm host group",
     [_tcp_rule("22")]),
    (server_group, "GraphTerm server group",
     [_tcp_rule(port) for port in ("22", "80", "81", "443", "444", "8888")]
     + [_host_port_rule]
     + [_tcp_rule(port) for port in ("8900", "8901")]),
    (nb_server_group, "GraphTerm notebook server group",
     [_tcp_rule(port) for port in ("22", "80", "81", "443", "444")]
     + [_host_port_rule]
     + [_tcp_rule(port) for port in ("8900", "8901")]
     + [_tcp_rule("10000", "12000")]),
    ]

# Marker files created in the login user's home directory by the startup script
setup_file = "SETUP_STATUS"
setup_errfile = "SETUP_ERROR"

# Startup script passed to the instance as user data. The first lines run
# directly; everything after the "cat ... << EOF" line is written to
# /root/ec2launch.sh until the matching EOF is appended later.
startup_commands = [
    "#!/bin/bash",
    "AWS_PUBLIC_HOST=`curl http://169.254.169.254/latest/meta-data/public-hostname`",
    "AWS_PUBLIC_IP=`curl http://169.254.169.254/latest/meta-data/public-ipv4`",
    "AWS_LOCAL_IP=`curl http://169.254.169.254/latest/meta-data/local-ipv4`",
    "cat > /root/ec2launch.sh << EOF",
    "#!/bin/bash",
    "set -e -x",
    "apt-get update && apt-get upgrade -qy",
    "apt-get install -qy build-essential emacs python-dev python-setuptools python-pip zip",
    "pip install tornado",
]

# Install graphterm via pip unless the install file's name starts with "graphterm"
install_gterm = not install_basename.startswith("graphterm")

if install_gterm:
    startup_commands.append("pip install graphterm")

if options.pylab:
    startup_commands.extend([
        "apt-get install -qy python-numpy python-scipy python-matplotlib python-mpltoolkits.basemap python-scientific python-pandas python-jinja2 ipython",
        "pip install pil qrcode",
        "pip install --upgrade ipython pyzmq",
    ])

if options.netcdf:
    startup_commands.extend([
        "apt-get install -qy libnetcdf-dev netcdf-bin libhdf5-serial-dev",
        "pip install netCDF4",
    ])

if options.R:
    startup_commands.append("apt-get install -qy r-base libcurl4-openssl-dev libcairo2-dev libxt-dev")

# Program to start on the instance: a host when connecting to an existing
# GraphTerm server, otherwise a server
command_name = "gtermhost" if options.connect_to_gterm else "gtermserver"

# Collect the pieces of the command and join them with single spaces
command_parts = ["sudo -u "+gterm_user+" -i"]
if options.screen:
    # Run inside a detached screen session named after the command
    command_parts.append("screen -d -m")
    if options.logging:
        command_parts.append("-L")
    command_parts += ["-S", command_name, command_name]
else:
    command_parts += [command_name, "--daemon=start"]

for _enabled, _switch in ((options.https, "--https"),
                          (options.allow_embed, "--allow_embed"),
                          (options.allow_share, "--allow_share"),
                          (options.notebook_server, "--nb_server")):
    if _enabled:
        command_parts.append(_switch)

if options.other_opts:
    command_parts.append(options.other_opts.strip())
if options.logging:
    command_parts.append("--logging")
if options.oshell:
    command_parts.append("--oshell")
    if options.screen:
        command_parts.append("--oshell_input")

command_line = " ".join(command_parts)

# Commands run at setup time and also recorded under /etc/init.d later on
init_commands = []
fwall_commands = []

if not options.connect_to_gterm:
    # Server mode
    security_groups = [nb_server_group if options.notebook_server else server_group]

    external_port = 443 if options.https else 80
    public_host = instance_tag if instance_domain else "${AWS_PUBLIC_HOST}"
    local_ip = "${AWS_LOCAL_IP}"

    # Redirect the external port and the one after it to the GraphTerm
    # default HTTP port and the one after it
    for offset in (0, 1):
        outer_port = external_port + offset
        inner_port = gterm.DEFAULT_HTTP_PORT + offset
        fwall_commands.append("iptables -t nat -A PREROUTING -p tcp --dport %d -j REDIRECT --to %d" % (outer_port, inner_port))
        fwall_commands.append("iptables -t nat -A OUTPUT --dst %s -p tcp --dport %d -j DNAT --to-destination %s:%d" % (local_ip, outer_port, local_ip, inner_port))

    command_line += " --auth_type="+options.auth_type
    if auth_type == "multiuser":
        command_line += " --user_setup=auto --super_users="+gterm_user
    command_line += " --host=%s --external_port=%d" % (public_host, external_port)
    ##if options.wildcard:
    ##    command_line += " --blob_host=wildcard"

    # The command is commented out with "#" when graphterm is not pip-installed
    # by the startup script
    init_prefix = "" if install_gterm else "#"
    init_commands = [init_prefix + command_line]
else:
    # Host mode: connect to an existing GraphTerm server
    security_groups = [host_group]
    server_name, sep, server_domain = options.connect_to_gterm.partition(".")
    if not server_domain:
        raise Exception("Must specify fully qualified domain name to connect to GraphTerm server")
    # Use the short name when the instance shares the server's domain
    host_name = instance_name if instance_domain == server_domain else instance_tag
    command_line += " --server_addr="+ options.connect_to_gterm+" "+host_name
    init_commands = [command_line]

# Prefix for running a command as the login user on the instance
as_gterm_user = "sudo -u "+gterm_user+" -i "

# Restrict access to the login user's home and GraphTerm settings directory
startup_commands += ["chmod go-rwx ~"+gterm_user,
                     "echo set -o ignoreeof >> /etc/bash.bashrc",
                     as_gterm_user+"mkdir -p "+gterm_dir,
                     as_gterm_user+"chmod go-rwx "+gterm_dir]

if options.gmail_addr:
    # Serialize with the json module so that quotes/backslashes in the
    # credentials are escaped; for plain values the text is the same as
    # the previously hand-formatted string.
    gauth_json = json.dumps({"google_oauth": {"key": gclient_id, "secret": gclient_secret}},
                            sort_keys=True)
    # The files are created with touch via sudo first, then filled by echo
    # with a redirect
    startup_commands += [as_gterm_user+"touch "+gterm_dir+"/gterm_email.txt",
                         as_gterm_user+"touch "+gterm_dir+"/gterm_oauth.json",
                         as_gterm_user+"echo "+options.gmail_addr+" > "+gterm_dir+"/gterm_email.txt",
                         as_gterm_user+"echo '"+gauth_json+"' > "+gterm_dir+"/gterm_oauth.json"]

if fwall_commands:
    # Run the firewall commands now and also record them in /etc/init.d/firewall
    startup_commands += fwall_commands
    startup_commands += ['echo "'+cmd+'" >> /etc/init.d/firewall' for cmd in fwall_commands]
    startup_commands += ["chmod +x /etc/init.d/firewall",
                         "update-rc.d firewall defaults"]

if init_commands:
    # Same treatment for the GraphTerm start command(s)
    startup_commands += init_commands
    startup_commands += ['echo "'+cmd+'" >> /etc/init.d/graphterm' for cmd in init_commands]
    startup_commands += ["chmod +x /etc/init.d/graphterm",
                         "update-rc.d graphterm defaults"]

startup_commands += ["echo ''",
                     "echo WORKAROUND: Doing 'pip uninstall dap' to eliminate dap module loading errors for Ubuntu 12.04LTS",
                     "yes | pip uninstall dap || true",
                     "echo ec2launch SETUP COMPLETED"]

# Close the here-document, run the generated script with output logged, and
# leave marker files in the login user's home: the error marker on non-zero
# exit, and the status marker holding the last 10 log lines.
startup_commands += ["EOF",
                     "chmod go-rwx /root/ec2launch.sh  # Hide secrets etc.",
                     "/bin/bash /root/ec2launch.sh >/root/ec2launch.log 2>&1",
                     "if [ $? -ne 0 ]; then",
                     "  "+as_gterm_user+"touch ~" + gterm_user + "/" + setup_errfile,
                     "fi",
                     as_gterm_user+"touch ~" + gterm_user + "/" + setup_file,
                     "tail -10 /root/ec2launch.log > ~" + gterm_user + "/" + setup_file,
                     ]
startup_script = "\n".join(startup_commands) + "\n"

if options.verbose:
    print("Startup_commands:\n   "+ "\n   ".join(startup_commands) + "\n\n", file=sys.stderr)

# Connect to EC2
ec2 = boto.connect_ec2()

# When the hostname includes a domain, look up its Route53 hosted zone.
# A missing zone is created, and its nameservers are remembered so the user
# can be told to register them.
route53conn = None
hosted_zone = None
nameservers = ""
if instance_domain:
    route53conn = ec2common.Route53Connection()
    hosted_zone = ec2common.get_hosted_zone(route53conn, instance_domain)
    if not hosted_zone:
        if options.dry_run:
            # A dry run must not create AWS resources; the script exits at
            # the dry-run check before hosted_zone/nameservers are used.
            print("Dry run: hosted zone for %s would be created" % instance_domain, file=sys.stderr)
        else:
            hosted_zone = ec2common.create_hosted_zone(route53conn, instance_domain)
            nameservers = ec2common.get_nameservers(route53conn, instance_domain)

def get_security_group(ec2conn, group_name):
    """Return the first security group named group_name, or None if absent."""
    matching = (grp for grp in ec2conn.get_all_security_groups() if grp.name == group_name)
    return next(matching, None)

def create_or_update_security_group(ec2conn, group_name, description="", rules=None):
    """Create (or update) security group

    If a group named group_name exists, every existing grant that does not
    correspond to an entry in rules is revoked; otherwise the group is
    created, using description (or the group name when description is
    empty). Rules not already present are then authorized.

    ec2conn: EC2 connection object
    group_name: name of the security group
    description: description used when the group has to be created
    rules: iterable of GroupRule tuples describing the desired ingress rules

    Returns the security group object.
    Raises Exception if a rule names a source group that cannot be found.
    Errors from authorizing individual rules are not raised; they are only
    printed to stderr when options.verbose is set.
    """
    group = get_security_group(ec2conn, group_name)
    # Work on a private copy: matched rules are removed from it below
    new_rules = list(rules) if rules else []
    if group:
        # Group already exists.
        # Iterate over snapshots: group.revoke() may prune entries from
        # group.rules / rule.grants as a side effect (NOTE(review): confirm
        # against the installed boto version), which would otherwise make
        # the loops skip entries.
        for rule in list(group.rules):
            # Check every grant of each rule, not just the first one
            for grant in list(rule.grants):
                # Grants without a CIDR are group grants; they are recorded
                # with the "0.0.0.0/0" placeholder used in the rule table
                cidr_ip = grant.cidr_ip if grant.cidr_ip else "0.0.0.0/0"
                src_group_name = None if grant.cidr_ip else grant.name
                old_rule = GroupRule(rule.ip_protocol,
                                     rule.from_port,
                                     rule.to_port,
                                     cidr_ip,
                                     src_group_name)
                if old_rule in new_rules:
                    # Old rule still valid
                    new_rules.remove(old_rule)
                else:
                    # Old rule no longer valid
                    group.revoke(ip_protocol=rule.ip_protocol,
                                 from_port=rule.from_port,
                                 to_port=rule.to_port,
                                 cidr_ip=cidr_ip,
                                 src_group=get_security_group(ec2conn, src_group_name) if src_group_name else None)
    else:
        group = ec2conn.create_security_group(group_name, description or group_name)

    for rule in new_rules:
        # Create new rules
        if rule.src_group_name:
            src_group = get_security_group(ec2conn, rule.src_group_name)
            if not src_group:
                raise Exception("Source group %s not found" % rule.src_group_name)
        else:
            src_group = None

        try:
            group.authorize(ip_protocol=rule.ip_protocol,
                            from_port=rule.from_port,
                            to_port=rule.to_port,
                            cidr_ip=rule.cidr_ip,
                            src_group=src_group)
        except Exception as excp:
            # Best effort: a failure for one rule does not stop the others
            if options.verbose:
                print("Rule creation error: "+str(excp)+"\n\n", file=sys.stderr)

    return group

# Create key pair, if needed (i.e. when the local private key file is missing)
if not os.path.exists(key_file):
    if options.dry_run:
        # A dry run must not create AWS resources or write key files
        print("Dry run: SSH key pair %s would be created" % key_name, file=sys.stderr)
    else:
        key_pair = ec2.create_key_pair(key_name)
        key_pair.save(ssh_dir)
        # Make the private key readable/writable by the owner only
        os.chmod(key_file, 0o600)
        print("Created SSH key file", key_file, file=sys.stderr)

# Remote shell command run over ssh after launch: poll until the status
# marker file exists, print it, and then (unless the error marker exists)
# unpack/install the copied install file. Steps are joined with "; ".
wait_steps = ["while [ ! -f "+setup_file+" ]; do sleep 1; done",
              "echo "+setup_file,
              "cat "+setup_file]
if copy_files and options.install:
    wait_steps.append("if [ ! -f "+setup_errfile+" ]; then :")
    if unarch_cmd:
        wait_steps.append(unarch_cmd+" "+install_basename)

    if install_cmd == "pip":
        wait_steps.append("sudo pip install "+install_basename+" > install.log")
        if not install_gterm:
            # Start the server/host once the copied graphterm package is installed
            wait_steps.append(command_line)
    elif install_cmd:
        wait_steps.append(install_cmd)

    wait_steps.append("fi")

wait_cmd = "; ".join(wait_steps)

# Arguments for run_instances (the startup script is added at launch time)
launch_params = dict(image_id=options.ami,
                     instance_type=options.type,
                     key_name=key_name,
                     security_groups=security_groups)

if options.dry_run:
    # Show what would be launched, then exit with status 1
    print("run_instances:", launch_params, file=sys.stderr)
    print("STARTUP_SCRIPT:\n", startup_script, file=sys.stderr)
    print("WAIT_CMD:\n", wait_cmd, file=sys.stderr)
    sys.exit(1)

# Create security groups as needed
for group_spec in SecurityGroups:
    create_or_update_security_group(ec2, group_spec[0],
                                    description=group_spec[1],
                                    rules=group_spec[2])

# Launch instance, passing the startup script as user data
reservation = ec2.run_instances(user_data=startup_script, **launch_params)
instance = reservation.instances[0]

# Wait for instance to start running
Status_template =  """<em>Creating instance</em> <b>%s</b>: status=<b>%s</b> (waiting %ds)"""

status = instance.update()
start_time = time.time()

if options.text:
    print("Waiting for instance to be created...", file=sys.stderr)

# Poll every 3 seconds for as long as the instance reports "pending"
while True:
    if status != "pending":
        break
    timesec = int(time.time() - start_time)
    if not options.text:
        # Refresh the progress pagelet in place
        status_html = Status_template % (instance_tag, status, timesec)
        gterm.write_pagelet(status_html, overwrite=True, display="fullpage")
    time.sleep(3)
    status = instance.update()

# Any final state other than "running" is treated as a launch failure
if status != "running":
    print("ec2launch: ERROR Failed to launch instance: %s" % status, file=sys.stderr)
    sys.exit(1)

# Tag instance
# NOTE(review): add_tag is called with the hostname as its only argument;
# in boto that presumably makes the hostname the tag key with an empty
# value -- confirm this is what the other ec2 scripts expect.
instance.add_tag(instance_tag)

# Look the launched instance up again via its reservation id
reservation_id = reservation.id
instance_obj = None
all_instances = ec2.get_all_instances()
for r in all_instances:
    if r.id == reservation_id:
        instance_obj = r.instances[0]
        break

if not instance_obj:
    # Report the id that was searched for (status is always "running" here)
    print("ec2launch: ERROR Unable to find launched instance: %s" % reservation_id, file=sys.stderr)
    sys.exit(1)

# Id of the instance itself (not of the reservation), used in later messages
instance_id = instance_obj.id

public_dns_name = instance_obj.public_dns_name
public_ip_addr = instance_obj.ip_address
# Prefer the requested full domain name, if any, for the web address
web_host = instance_tag if instance_domain else public_dns_name
web_url = ("https" if options.https else "http") + "://" + web_host

if hosted_zone:
    # Create new CNAME entry pointing to instance public DNS
    ec2common.cname(route53conn, hosted_zone, instance_tag, public_dns_name)
    ##if options.wildcard:
    ##    ec2common.cname(route53conn, hosted_zone, "*."+instance_tag, public_dns_name)
def _note(*parts):
    """Print the given values to stderr, like print(..., file=sys.stderr)."""
    print(*parts, file=sys.stderr)


# user@host login string for the new instance
remote_login = gterm_user+"@"+public_dns_name

# Summary of the created instance and how to follow the setup log
_note("")
_note("Created EC2 instance %s: id=%s, ip=%s, domain=%s" % (instance_tag, instance_id, public_ip_addr, public_dns_name))
_note("Setup commands are saved in /root/ec2launch.sh and init commands in /etc/init.d/graphterm")
_note("To check setup progress, type:")
_note("   ec2ssh "+remote_login+" sudo tail -f /root/ec2launch.log")

if install_gterm:
    _note("On successful instance creation, use the following URL to access server:")
    _note("   "+web_url)

# Base ssh command; host key checking is turned off
ssh_cmd_args = ["ssh", "-i", key_file,
                "-o", "StrictHostKeyChecking=no", remote_login]

if nameservers:
    _note("***NOTE*** Please add the following nameservers for domain "+instance_domain)
    _note("           ", nameservers)
    _note("")

if options.gmail_addr:
    _note("")
    _note("***NOTE*** For Google Authentication, update following URI settings for your web application at "+gdev_url)
    _note("  Authorized Javascript origins: "+web_url)

    _note("  Authorized Redirect URI: "+web_url+"/_gauth")
    _note("")

# Pause for 45 seconds before connecting to the instance
_note("Waiting several minutes for remote setup to complete...")
time.sleep(45)

if copy_files:
    # Copy the requested files into the login user's home directory
    scp_cmd_args = ["scp", "-i", key_file, "-o", "StrictHostKeyChecking=no"]
    scp_cmd_args = scp_cmd_args + copy_files + [gterm_user+"@"+public_dns_name+":"]
    print(" ".join(scp_cmd_args), file=sys.stderr)
    out, err = gterm.command_output(scp_cmd_args, timeout=60)
    # Both output streams of scp are echoed to stdout
    for scp_text in (out, err):
        if scp_text:
            print(scp_text)

# Run the wait command on the instance (600 second timeout)
print(" ".join(ssh_cmd_args), file=sys.stderr)
print(wait_cmd, file=sys.stderr)
out, err = gterm.command_output(ssh_cmd_args + [wait_cmd], timeout=600)

if out:
    print(out)
# Report completion only when the ssh command produced no error output
print(err if err else "Remote setup completed.", file=sys.stderr)
    
