#!/usr/bin/python3

import sys, os, subprocess, time, signal

import threading

try:
    from argparse import ArgumentParser
    using_optparse = False
except ImportError:
    from optparse import OptionParser as ArgumentParser
    using_optparse = True

try:
    import psutil
    have_psutil = True
except ImportError:
    have_psutil = False

# ProcessLookupError only exists from Python 3.3 on.  Older interpreters
# report a missing process from os.kill() as a plain OSError, so alias that.
# Caveat: there, every `except ProcessLookupError` in this file therefore
# catches any OSError (permission errors included), not just missing pids.
if sys.version_info < (3, 3):
    ProcessLookupError = OSError

VERSION = '0.1'

# Verbosity threshold used by DEBUG().  The command-line entry point sets it
# from -v; this default keeps DEBUG() usable when the file is imported.
debuglevel = 0


def DEBUG(level, *msg, end='\n'):
    """Print msg to stderr, like print(), if level <= the global debuglevel.

    stderr is flushed immediately so that output written without a trailing
    newline (e.g. with end='\\b') appears right away.
    """
    if level <= debuglevel:
        print(*msg, file=sys.stderr, end=end)
        sys.stderr.flush()


def children_pid_psutil(pid):
    """Return the set of pids of all descendants of pid, using psutil.

    Descendants means children, grandchildren and so on.  Returns an empty
    set when pid does not exist or disappears while being inspected.  This
    matches children_pid_ps, so callers notice a missing process through
    os.kill() rather than through a psutil-specific exception.
    """
    try:
        proc = psutil.Process(pid)
        # psutil 2.0 renamed get_children() to children(); later releases
        # dropped the old name, so prefer the new one when it exists.
        list_children = getattr(proc, 'children', None) or proc.get_children
        descendants = list_children(recursive=True)
    except psutil.NoSuchProcess:
        return set()
    return set(child.pid for child in descendants)


def children_pid_ps(pid):
    """Return the set of pids of all descendants of pid.

    Descendants means children, grandchildren and so on.  The process table
    is read with `ps`, using GNU/procps options, so this is Linux only.
    """
    # one "pid ppid" pair per line, for every process on the system
    ps = subprocess.Popen('ps --no-headers -o pid,ppid ax', shell=True,
                          stdout=subprocess.PIPE)
    output = ps.communicate()[0]

    # invert the table: parent pid -> list of its direct children
    kids_of = {}
    for line in output.splitlines():
        if not line:
            continue
        fields = line.split()
        child, parent = int(fields[0]), int(fields[1])
        kids_of.setdefault(parent, []).append(child)

    # walk the tree downwards starting from pid
    descendants = set()
    pending = [pid]
    while pending:
        for kid in kids_of.get(pending.pop(), ()):
            if kid not in descendants:
                descendants.add(kid)
                pending.append(kid)
    return descendants

# Descendant lookup backend: psutil when it could be imported, otherwise
# parse the output of `ps` (Linux only).
children_pid = children_pid_psutil if have_psutil else children_pid_ps


def kill_children(pid, sig, raise_parent):
    """Send signal sig to every descendant of pid, then to pid itself.

    Descendants that disappear before they can be signalled are skipped.
    If raise_parent is true and pid itself no longer exists, the
    ProcessLookupError from os.kill() propagates, so the caller learns the
    process is gone.  Otherwise a missing pid is silently ignored.
    """
    for childpid in children_pid(pid):
        try:
            os.kill(childpid, sig)
        except ProcessLookupError:
            pass  # the child disappeared in the meantime
    # do the same with the main process
    try:
        os.kill(pid, sig)
    except ProcessLookupError:
        if raise_parent:
            raise

def watcher(pid, run_period, sleep_period, stop_event=None):
    """Throttle pid and all of its descendants until pid goes away.

    Each cycle lets the tree run for run_period seconds, SIGSTOPs it, keeps
    it stopped for sleep_period seconds, then SIGCONTs it.  Returns when pid
    no longer exists.  If stop_event (e.g. a threading.Event) is given, it
    also returns once the event is set; a pause in progress is then cut
    short.  Processes stopped by a cycle are always resumed before leaving
    it, including when an exception such as KeyboardInterrupt interrupts it.
    """

    def pause(seconds):
        # Wait `seconds`; return True if the caller asked us to stop.
        if stop_event is None:
            time.sleep(seconds)
            return False
        stop_event.wait(seconds)
        return stop_event.is_set()

    while not pause(run_period):
        # Snapshot right before stopping, so the log shows what gets stopped.
        children = children_pid(pid)
        DEBUG(1, '$', end='\b')
        DEBUG(2, time.asctime(), 'STOP:', pid, children)
        try:
            kill_children(pid, signal.SIGSTOP, raise_parent=True)
            pause(sleep_period)
            DEBUG(1, '=', end='\b')
            DEBUG(2, time.asctime(), 'CONT:', pid, children)
        except ProcessLookupError:
            # pid is gone: nothing left to throttle
            break
        finally:
            # Be careful to leave the processes running if something bad
            # happens.  Resume the snapshot as well as the current tree: if
            # pid died meanwhile, its children were reparented and a fresh
            # children_pid(pid) lookup no longer finds them.
            for childpid in children | children_pid(pid):
                try:
                    os.kill(childpid, signal.SIGCONT)
                except ProcessLookupError:
                    pass  # already gone
            try:
                os.kill(pid, signal.SIGCONT)
            except ProcessLookupError:
                pass


def run_subprocess(cmd, run_period, sleep_period):
    """Run cmd through the shell and throttle it (and its descendants).

    cmd is handed to the shell verbatim (shell=True).  On CTRL C the tree is
    resumed and the main process gets SIGINT; a second CTRL C during that
    cleanup exits with status 1.  Before returning, the watcher thread is
    stopped and joined, so it cannot leave anything suspended behind.
    """
    p = subprocess.Popen(cmd, shell=True)
    stop = threading.Event()
    t = threading.Thread(target=watcher,
                         args=(p.pid, run_period, sleep_period, stop))
    # daemon, so a stuck watcher can never keep the interpreter alive
    t.daemon = True
    t.start()
    try:
        p.communicate()
    except KeyboardInterrupt:
        try:
            DEBUG(2, '\nCTRL C detected')
            # Stop throttling first, so the watcher cannot re-suspend the
            # process while it is shutting down.
            stop.set()
            # unstop all children
            kill_children(p.pid, signal.SIGCONT, raise_parent=False)
            # send SIGINT to the main process
            try:
                os.kill(p.pid, signal.SIGINT)
            except ProcessLookupError:
                pass
            time.sleep(0.5)  # give the main process time to clean up
        except KeyboardInterrupt:
            # if the user pressed CTRL C quickly twice
            DEBUG(2, 'CTRL C pressed multiple times, exiting')
            sys.exit(1)
    finally:
        # Let the watcher finish its current cycle: it resumes everything it
        # stopped before returning.
        stop.set()
        t.join()

def watch_pid(pid, run_period, sleep_period):
    """Throttle an already running pid and its descendants until it exits.

    CTRL C ends the throttling.  Whatever ends it, every process is sent
    SIGCONT on the way out, so nothing is left stopped.
    """
    try:
        watcher(pid, run_period, sleep_period)
    except KeyboardInterrupt:
        DEBUG(1, '\nCTRL C detected, ending...')
    finally:
        # unstop processes, letting them run after we finish
        kill_children(pid, signal.SIGCONT, raise_parent=False)


def parse_niceness(text):
    """Parse a niceness spec into (sleep_period, run_period), in seconds.

    'S' sleeps S seconds for every second of running.  'S/R' sleeps S
    seconds for every R seconds of running.  Examples: '1', '5', '1/3',
    '0.4'.

    Raises ValueError if text is malformed or a period is out of range.
    The sleep period must be >= 0, the run period must be > 0, and both
    must be finite.
    """
    sleep_text, slash, run_text = text.partition('/')
    sleep_period = float(sleep_text)
    run_period = float(run_text) if slash else 1.0
    inf = float('inf')
    # chained comparisons are all False for NaN, so NaN is rejected too
    if not (0 <= sleep_period < inf and 0 < run_period < inf):
        raise ValueError('niceness out of range: %r' % (text,))
    return sleep_period, run_period


def main():
    """Command-line entry point: throttle a running pid (-p) or a new command."""
    global debuglevel

    parser = ArgumentParser(prog='verynice',
                            description='Aggressively throttle processes.')
    if using_optparse:
        # optparse spells these add_option()/parse_args(); alias them so the
        # option definitions below work with either parser.
        parser.add_argument = parser.add_option
        parser.parse_known_args = parser.parse_args

    parser.add_argument('-p', dest='pid', action='store', type=int,
                        help='Be nice to this pid (and all the children)')
    parser.add_argument('-v', dest='debuglevel', action='count', default=0,
                        help='Be verbose (repeat for more)')
    parser.add_argument('-n', '--niceness', '--adjustment', dest='niceness',
                        action='store', type=str, default='1',
                        help='Niceness, e.g. 1, 5, 1/3, 0.4')
    if not using_optparse:
        parser.add_argument('--version', action='version',
                            version='%(prog)s ' + VERSION)

    # The leftover arguments form the command to run (with argparse this
    # includes options it does not recognise).
    args, rest = parser.parse_known_args()

    debuglevel = args.debuglevel
    DEBUG(3, 'args:', str(args))
    DEBUG(3, 'optparse:', using_optparse)
    DEBUG(3, 'psutil available:', have_psutil)
    try:
        sleep_period, run_period = parse_niceness(args.niceness)
    except ValueError:
        parser.error('invalid niceness: %r' % (args.niceness,))

    DEBUG(2, 'debuglevel', debuglevel)
    DEBUG(2, 'sleep period:', sleep_period)
    DEBUG(2, 'run period:', run_period)
    DEBUG(2, 'pid:', args.pid)

    cmd = ' '.join(rest)
    DEBUG(1, 'CMD:', cmd)

    if args.pid is not None and cmd:
        print('You cannot combine -p and command.')
        print()
        parser.print_help()
        sys.exit(1)
    if args.pid == 1:
        print('Cowardly refusing to suspend init(8) and its children.')
        sys.exit(1)
    if args.pid is not None and args.pid < 1:
        # os.kill() treats 0 and negative pids specially (process groups,
        # or every process we may signal): never throttle those.
        parser.error('pid must be a positive integer, got %d' % args.pid)
    if args.pid is not None:
        watch_pid(args.pid, run_period, sleep_period)
    elif cmd:
        run_subprocess(cmd, run_period, sleep_period)
    else:
        print('Usage: verynice [options] command')
        parser.print_help()


if __name__ == '__main__':
    main()

