#! /usr/bin/python3

import sys, os, subprocess, time, signal

import threading

try:
    from argparse import ArgumentParser
    using_optparse = False
except ImportError:
    from optparse import OptionParser as ArgumentParser
    using_optparse = True

try:
    import psutil
    have_psutil = True
except ImportError:
    have_psutil = False

# ProcessLookupError was added in python 3.3; on older interpreters fall back
# to the more general OSError so that the `except ProcessLookupError` clauses
# used throughout this file still work
try:
    ProcessLookupError
except NameError:
    ProcessLookupError = OSError

# program version, reported by the --version option
VERSION='0.3'

def DEBUG(level, *msg, end='\n'):
    """Write a diagnostic message to stderr if the verbosity is high enough.

    The message is emitted only when `level` does not exceed the module-level
    `debuglevel`. The `msg` items are printed separated by spaces and followed
    by `end`, and stderr is flushed immediately afterwards.
    """
    if level > debuglevel:
        return
    stream = sys.stderr
    print(*msg, end=end, file=stream)
    stream.flush()


def children_pid_psutil(pid):
    """Return the pids of all descendants of `pid`, looked up through psutil.

    The result is a set holding the pids of the children, grandchildren and
    so on; `pid` itself is not included. An empty set is returned when the
    process does not exist (psutil.NoSuchProcess).
    """
    try:
        proc = psutil.Process(pid)
        # Current psutil releases call this method children(); old releases
        # only offered get_children(). Use whichever one is available.
        list_children = getattr(proc, 'children', None)
        if list_children is None:
            list_children = proc.get_children
        return set(child.pid for child in list_children(recursive=True))
    except psutil.NoSuchProcess:
        return set()

def children_pid_ps(pid):
    """Return the pids of all descendants of `pid` by parsing the output of ps.

    Runs `ps --no-headers -o pid,ppid ax` (procps syntax, so effectively linux
    only) and follows the parent -> child links starting from `pid`. The
    result is a set of pids that does not include `pid` itself unless the
    process table lists it as its own descendant.
    """
    ps = subprocess.Popen('ps --no-headers -o pid,ppid ax', shell=True, stdout=subprocess.PIPE)
    listing = ps.communicate()[0]

    # index the process table: parent pid -> pids it directly spawned
    spawned_by = {}
    for row in listing.splitlines():
        if not row:
            continue
        fields = row.split()
        child, parent = int(fields[0]), int(fields[1])
        spawned_by.setdefault(parent, []).append(child)

    # walk the tree from pid, collecting everything reachable
    descendants = set()
    pending = [pid]
    while pending:
        for child in spawned_by.get(pending.pop(), ()):
            if child not in descendants:
                descendants.add(child)
                pending.append(child)
    return descendants

# use the psutil based lookup whenever psutil could be imported,
# otherwise fall back to parsing the output of ps
children_pid = children_pid_psutil if have_psutil else children_pid_ps


def kill_children(pid, sig, raise_parent):
    """Send signal `sig` to every descendant of `pid` and then to `pid` itself.

    Descendants that have disappeared in the meantime are silently skipped.
    If `raise_parent` is true, a missing `pid` process raises
    ProcessLookupError to tell the caller that it does not exist anymore;
    otherwise that error is ignored as well.
    """
    for descendant in children_pid(pid):
        try:
            os.kill(descendant, sig)
        except ProcessLookupError:
            # the child exited between the lookup and the kill
            pass

    # finally signal the main process
    try:
        os.kill(pid, sig)
    except ProcessLookupError:
        if raise_parent:
            raise


def renice_pid_psutil(pid, niceness, ioclass, ioclassdata):
    """Set the scheduler and i/o priority of a single process through psutil.

    `niceness` is the scheduler niceness to apply, `ioclass` / `ioclassdata`
    the i/o scheduling class and its class data; a value of None for
    `niceness` or `ioclass` leaves that setting untouched. A process that
    has disappeared in the meantime is silently ignored.
    """
    try:
        p = psutil.Process(pid)
        if niceness is not None:
            DEBUG(2, 'setting scheduler priority of {} to {}'.format(pid, niceness))
            # old psutil releases use set_nice(value), current ones nice(value)
            set_nice = getattr(p, 'set_nice', None) or p.nice
            set_nice(niceness)
        if ioclass is not None:
            DEBUG(2, 'setting i/o class of {} to {}, classdata {}'.format(pid, ioclass, ioclassdata))
            # same renaming: set_ionice(...) became ionice(...)
            set_ionice = getattr(p, 'set_ionice', None) or p.ionice
            set_ionice(ioclass, ioclassdata)
    except (psutil.NoSuchProcess, ProcessLookupError):
        # if the process disappeared in the meantime; psutil reports this
        # with its own NoSuchProcess exception
        pass


# set priorities using only python library modules
def renice_pid_internal(pid, niceness, ioclass, ioclassdata):
    """Set the scheduler priority of a single process without psutil.

    Only the niceness can be changed this way; a requested i/o class is
    reported as ignored. A value of None for `niceness` or `ioclass` means
    that the setting was not requested, so nothing is done (and nothing is
    reported) for it. A process that has disappeared in the meantime is
    silently ignored.
    """
    if niceness is not None:
        # os.setpriority has been added in python3.3
        if hasattr(os, 'setpriority'):
            try:
                DEBUG(2, 'setting scheduler priority of {} to {}'.format(pid, niceness))
                os.setpriority(os.PRIO_PROCESS, pid, niceness)
            except ProcessLookupError:
                pass
        else:
            DEBUG(0, 'psutil not installed, renice ignored')
    if ioclass is not None:
        DEBUG(0, 'psutil not installed, ionice ignored')

# renice through psutil whenever it could be imported,
# otherwise use the standard-library-only implementation
renice_pid = renice_pid_psutil if have_psutil else renice_pid_internal

# apply the priorities to a given pid and, recursively, to all of its children
def renice_children_pid(pid, niceness, ioclass, ioclassdata):
    "renice pid and every descendant of pid with the given niceness and i/o class"
    family = children_pid(pid) | {pid}
    for member in family:
        renice_pid(member, niceness, ioclass, ioclassdata)


def watcher(pid, run_period, sleep_period):
    """Throttle `pid` and its children by alternately running and stopping them.

    Each cycle sleeps `run_period` seconds while the processes run, sends
    SIGSTOP to `pid` and its children, sleeps `sleep_period` seconds and then
    sends SIGCONT. The loop ends once SIGSTOP can no longer be delivered to
    `pid` itself (kill_children raises ProcessLookupError).
    """
    while True:
        # this snapshot is only used by the level 2 debug messages below;
        # kill_children() looks the children up again on its own
        children = children_pid(pid)
        time.sleep(run_period)
        DEBUG(1, '$', end='\b')
        DEBUG(2, time.asctime(), 'STOP:', pid,  children)
        try:
            # raise_parent=True: a vanished main process ends the loop
            kill_children(pid, signal.SIGSTOP, raise_parent=True)
            time.sleep(sleep_period)
            DEBUG(1, '=', end='\b')
            DEBUG(2, time.asctime(), 'CONT:', pid, children)
        except ProcessLookupError:
            break
        finally:
            # be careful to leave the processes running if something bad happens
            # (this runs on every iteration, also on break and on exceptions)
            kill_children(pid, signal.SIGCONT, raise_parent=False)


def _signal_if_running(proc, sig):
    "send sig to the Popen object proc, unless the process has already been reaped"
    # Once the child has been waited for, its pid may be reused by an
    # unrelated process, so it must not be signalled any more.
    if proc.poll() is not None:
        return
    try:
        os.kill(proc.pid, sig)
    except ProcessLookupError:
        pass


def run_subprocess(cmd, run_period, sleep_period, niceness, ioclass, ioclassdata):
    """Run `cmd` through the shell and throttle it until it finishes.

    The new process is reniced once (`niceness`, `ioclass`, `ioclassdata`),
    then a daemon thread running watcher() alternately stops and resumes it
    using `run_period` / `sleep_period` while this function waits for the
    command to end.

    On CTRL C the children are resumed and SIGINT is sent to the main
    process; a second CTRL C during that cleanup exits with status 1.

    Returns the exit status of the command, or None if it was still running
    when this function gave up waiting for it.
    """
    p = subprocess.Popen(cmd, shell=True)
    try:
        # renice only once, all the newly created children will inherit the nice values
        renice_children_pid(p.pid, niceness, ioclass, ioclassdata)

        t = threading.Thread(target=watcher, args = (p.pid, run_period, sleep_period))
        t.daemon = True
        t.start()
        try:
            p.communicate()
        except KeyboardInterrupt:
            try:
                DEBUG(2, '\nCTRL C detected')
                # unstop all children, a stopped process cannot react to SIGINT
                kill_children(p.pid, signal.SIGCONT, raise_parent=False)
                # send SIGINT to the main process
                _signal_if_running(p, signal.SIGINT)
                time.sleep(0.5) # just in case, to give the main process time to clean up
            except KeyboardInterrupt:
                # if the user pressed CTRL C quickly twice
                DEBUG(2, '\nCTRL C pressed multiple times, killing processes and exiting')
                _signal_if_running(p, signal.SIGINT)
                sys.exit(1)
    finally:
        # if the spawned process still hangs around, ask it politely to finish
        _signal_if_running(p, signal.SIGTERM)
    return p.returncode


def watch_pid(pid, run_period, sleep_period):
    """Throttle an already running process `pid` (and its children).

    Runs watcher() in the calling thread until the process disappears or the
    user presses CTRL C. In every case SIGCONT is sent to `pid` and its
    children before returning, so they are left running. A second CTRL C
    during that cleanup exits with status 1.
    """
    try:
        watcher(pid, run_period, sleep_period)
    except KeyboardInterrupt:
        # catch ctrl C
        # unstop processes
        try:
            DEBUG(1, '\nCTRL C detected, ending...')
            kill_children(pid, signal.SIGCONT, raise_parent=False)
        except KeyboardInterrupt:
            # if the user pressed CTRL C quickly twice
            DEBUG(2, '\nCTRL C pressed multiple times, exiting')
            sys.exit(1)
        # and finish, letting them run...
        return
    finally:
        # if anything bad happened, at least try to unstop processes;
        # this is best effort only, so a failure is reported but not raised
        try:
            kill_children(pid, signal.SIGCONT, raise_parent=False)
        except Exception as exc:
            DEBUG(1, 'could not resume processes:', exc)


# Command line entry point: parse the options, then either throttle an
# existing process (-p) or start the given command and throttle it.
if __name__ == '__main__':

    parser = ArgumentParser(prog='verynice',
            description='Aggresively throttle processes.')
    if using_optparse:
        # make the optparse parser look like an argparse one
        parser.add_argument = parser.add_option
        parser.parse_known_args = parser.parse_args


    parser.add_argument('-p', dest='pid', action='store',
                       type=int,
                       help='Be nice to this pid (and all the children)')

    parser.add_argument('-v', dest='debuglevel', action='count',
                       default = 0,
                       help='Be verbose (repeat for more)')

    parser.add_argument('-c', '--class', dest='ioclass', action='store',
                       type=int,
                       default = None,
                       help='I/O scheduling class number (the same as ionice): 0: none, 1: realtime, 2: best-effort, 3: idle')

    parser.add_argument('--ion', '--classdata', dest='ioclassdata', action='store',
                       type=int,
                       default=None,
                       help='I/O scheduling class data: 0-7 for realtime and best-effort classes')

    parser.add_argument('--nn', '--niceness', dest='niceness', action='store',
                       type=int,
                       default=None,
                       help='Scheduling niceness (the same as for nice)')

    parser.add_argument('-n', dest='veryniceness', action='store',
                       type=str, default='1',
                       help='Throttling value (sleep/run), e.g. 1, 5, 1/3, 0.4')

    parser.add_argument('-a', dest='all_nice', action='store_true',
                       help='Set maximum possible nice values (equivalent to ionice -c 3 nice -n 20)')

    if not using_optparse:
        parser.add_argument('--version', action='version',
                       version='%(prog)s '+VERSION)


    # everything that is not a known option is taken as the command to run
    args, rest = parser.parse_known_args()

    global debuglevel
    debuglevel = args.debuglevel
    DEBUG(3, 'args:', str(args))
    DEBUG(3, 'optparse:', using_optparse)
    DEBUG(3, 'psutil available:', have_psutil)

    # -n is either "sleep" (run period stays 1 second) or "sleep/run"
    sleep_period = 1.0
    run_period = 1.0
    try:
        if '/' in args.veryniceness:
            sleep_text, run_text = args.veryniceness.split('/')
            sleep_period = float(sleep_text)
            run_period = float(run_text)
        else:
            sleep_period = float(args.veryniceness)
    except ValueError:
        # not a number, or more than one '/'
        parser.error('invalid throttling value for -n: {!r}'.format(args.veryniceness))
    for period in (sleep_period, run_period):
        # time.sleep() rejects negative, infinite and NaN values
        if not 0 <= period < float('inf'):
            parser.error('throttling periods for -n must be finite and not negative')


    # nice(1) emulation
    niceness = args.niceness

    # ionice emulation
    ioclass = args.ioclass
    ioclassdata = args.ioclassdata

    # -a only fills in the values that were not given explicitly
    if niceness is None and args.all_nice:
        niceness = 20
    if ioclass is None and args.all_nice:
        ioclass = 3

    DEBUG(2, 'debuglevel', debuglevel)
    DEBUG(2, 'sleep period:', sleep_period)
    DEBUG(2, 'run period:', run_period)
    DEBUG(2, 'niceness:', niceness)
    DEBUG(2, 'ioclass:', ioclass)
    DEBUG(2, 'ioclassdata:', ioclassdata)
    DEBUG(2, 'pid:', args.pid)

    # NOTE(review): joining with spaces drops the original quoting of the
    # arguments before the command is handed to the shell
    cmd = ' '.join(rest)
    DEBUG(1, 'CMD:', cmd)


    if args.pid and cmd:
        print ('You cannot combine -p and command.')
        print()
        parser.print_help()
        sys.exit(1)
    if args.pid is not None and args.pid < 1:
        # os.kill() treats 0 and negative pids as process groups or
        # "every process", which must never be stopped by accident
        parser.error('the pid given to -p must be a positive integer')
    if args.pid == 1:
        print ('Cowardly refusing to suspend init(8) and its children.')
        sys.exit(1)
    if args.pid:
        # use the computed values so that -a is honoured in -p mode as well
        renice_children_pid(args.pid, niceness, ioclass, ioclassdata)
        watch_pid(args.pid, run_period, sleep_period)
    elif cmd:
        run_subprocess(cmd, run_period, sleep_period, niceness, ioclass, ioclassdata)
    else:
        print ('usage: verynice [options] command')
        parser.print_help()

