#!/usr/bin/python3

from __future__ import print_function

import sys, os, subprocess, time, signal

import threading

try:
    from argparse import ArgumentParser, REMAINDER
    using_optparse = False
except ImportError:
    from optparse import OptionParser as ArgumentParser
    using_optparse = True

try:
    import psutil
    have_psutil = True
except ImportError:
    have_psutil = False

# ProcessLookupError was added in python 3.3
# On interpreters where the name is missing, alias it to OSError so that the
# `except ProcessLookupError` clauses in this file remain valid. Note that
# with the alias in place those clauses catch every OSError, not only
# "no such process".
try:
    ProcessLookupError
except NameError:
    ProcessLookupError = OSError

VERSION = '0.7'

# Verbosity threshold read by DEBUG(); the script entry point overwrites it
# with the number of -v flags given on the command line.
debuglevel = 0


def DEBUG(level, *msg, **args):
    """Write a diagnostic message to stderr if it is verbose enough.

    level -- the message is printed only when level <= debuglevel
    msg   -- values to print, separated by spaces as with print()
    end   -- optional keyword: line terminator (defaults to a newline)

    stderr is flushed after every printed message so that progress markers
    appear immediately.
    """
    if level > debuglevel:
        return
    terminator = args.get('end', '\n')
    print(*msg, file=sys.stderr, end=terminator)
    sys.stderr.flush()


def children_pid_psutil(pid):
    """Return the pids of all descendants of pid, looked up through psutil.

    The result is a set containing children, grandchildren and so on (the
    pid itself is not included). An empty set is returned when pid does not
    exist.
    """
    try:
        proc = psutil.Process(pid)
        # psutil 2.0 renamed Process.get_children() to Process.children()
        # and psutil 3.0 dropped the old name; support both spellings.
        list_children = getattr(proc, 'children', None) or proc.get_children
        return set(child.pid for child in list_children(recursive=True))
    except psutil.NoSuchProcess:
        return set()

def children_pid_ps(pid):
    """Return the pids of all descendants of pid, using the ps(1) command.

    The process table is read once with `ps ax -o pid,ppid`; the set of
    descendants (children, grandchildren, ...) is then grown until a full
    pass over the table adds nothing new. The pid itself is not included.
    """
    listing = subprocess.Popen('ps ax -o pid,ppid', shell=True, stdout=subprocess.PIPE)
    stdout = listing.communicate()[0]

    # (pid, parent pid) pairs; the first output line is the column header
    table = []
    for row in stdout.splitlines()[1:]:
        if not row:
            continue
        fields = row.split()
        table.append((int(fields[0]), int(fields[1])))

    descendants = set()
    grew = True
    while grew:
        before = len(descendants)
        for child, parent in table:
            if parent == pid or parent in descendants:
                descendants.add(child)
        # stop once a whole pass found no new descendant
        grew = len(descendants) != before
    return descendants

# Pick the descendant-lookup implementation: psutil when it could be
# imported, otherwise the ps(1)-based fallback.
children_pid = children_pid_psutil if have_psutil else children_pid_ps


def kill_children(pid, sig, raise_parent):
    """Send signal sig to pid and to all of its descendants.

    The list of descendants is collected before any signal is sent, then the
    parent is signalled first and the descendants afterwards.

    If raise_parent is true, a ProcessLookupError raised while signalling
    the parent propagates to the caller, telling it that the process does
    not exist anymore; otherwise that error is ignored. Descendants that
    disappeared in the meantime are always ignored.
    """
    descendants = children_pid(pid)

    # the main process goes first
    DEBUG(3, 'killing pid', pid, 'with signal', sig)
    try:
        os.kill(pid, sig)
    except ProcessLookupError:
        if raise_parent:
            raise

    # then everything below it
    for target in descendants:
        DEBUG(3, 'killing pid', target, 'with signal', sig)
        try:
            os.kill(target, sig)
        except ProcessLookupError:
            # the child went away between the lookup and the signal
            pass

def renice_pid_psutil(pid, niceness, ioclass, ioclassdata):
    """Set scheduler niceness and I/O priority of one process via psutil.

    niceness    -- new nice value, or None to leave it unchanged
    ioclass     -- I/O scheduling class, or None to leave it unchanged
    ioclassdata -- value within the I/O class (may be None)

    A process that disappeared in the meantime is silently ignored; a
    permission problem is reported through DEBUG and otherwise ignored,
    matching the best-effort behaviour of the command-line fallback.
    """
    try:
        proc = psutil.Process(pid)
        if niceness is not None:
            DEBUG(2, 'setting scheduler priority of {} to {}'.format(pid, niceness))
            # psutil >= 2.0 spells the setter nice(); set_nice() is the
            # pre-2.0 name that was removed in 3.0.
            set_nice = getattr(proc, 'set_nice', None) or proc.nice
            set_nice(niceness)
        if ioclass is not None:
            # Same rename for set_ionice() -> ionice(). The method is looked
            # up with getattr because methods live on the class, never in
            # the instance __dict__; psutil does not provide it everywhere.
            set_ionice = (getattr(proc, 'set_ionice', None)
                          or getattr(proc, 'ionice', None))
            if set_ionice is None:
                DEBUG(1, 'ionice not supported by psutil on this platform')
            else:
                DEBUG(2, 'setting i/o class of {} to {}, classdata {}'.format(pid, ioclass, ioclassdata))
                set_ionice(ioclass, ioclassdata)
    except (psutil.NoSuchProcess, ProcessLookupError):
        # if the process disappeared in the meantime
        pass
    except psutil.AccessDenied as denied:
        DEBUG(1, 'not permitted to change priority of {}: {}'.format(pid, denied))

def suppress_output(cmd):
    """Run cmd, hiding its output unless debugging is enabled.

    cmd is an argument list passed to subprocess.check_output with stderr
    merged into stdout. The captured text is printed through DEBUG: at
    level 3 on success, at level 2 when the command exits with a non-zero
    status or cannot be found. Other OS-level failures are re-raised.
    """
    import errno

    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as cperr:
        DEBUG(2, 'error:', cperr.output.decode('utf-8', 'replace'))
    except OSError as oserr:
        # FileNotFoundError only exists on python >= 3.3; checking errno on
        # OSError is equivalent there and also works on older interpreters.
        if oserr.errno != errno.ENOENT:
            raise
        DEBUG(2, 'command not found:', cmd)
    else:
        # 'replace' keeps undecodable tool output from raising here
        DEBUG(3, output.decode('utf-8', 'replace'))

def renice_pid_fallback(pid, niceness, ioclass, ioclassdata):
    """Set priorities of one process with the renice and ionice commands.

    niceness    -- new nice value, or None to skip renice
    ioclass     -- I/O scheduling class for `ionice -c`, or None to skip
    ioclassdata -- optional class data for `ionice -n`

    ionice is only attempted when sys.platform is 'linux' or 'linux2'.
    """
    if niceness is not None:
        DEBUG(2, 'setting scheduler priority of {} to {}'.format(pid, niceness))
        renice_cmd = ['renice', '-n', str(niceness), '-p', str(pid)]
        DEBUG(2, 'calling '+' '.join(renice_cmd))
        suppress_output(renice_cmd)

    if ioclass is None:
        return
    if sys.platform not in ('linux', 'linux2'):
        DEBUG(1, 'ionice not supported on non-Linux platforms')
        return

    DEBUG(2, 'setting i/o class of {} to {}, classdata {}'.format(pid, ioclass, ioclassdata))
    ionice_cmd = ['ionice', '-c', str(ioclass)]
    if ioclassdata is not None:
        ionice_cmd.extend(['-n', str(ioclassdata)])
    ionice_cmd.extend(['-p', str(pid)])
    DEBUG(2, 'calling '+' '.join(ionice_cmd))
    suppress_output(ionice_cmd)

def renice_pid_internal(pid, niceness, ioclass, ioclassdata):
    """Set priorities of one process using only python library modules.

    The nice value is applied with os.setpriority (added in python 3.3)
    when that function exists; otherwise the request is reported and
    dropped. I/O priorities cannot be set this way, so any ioclass or
    ioclassdata request is only reported.
    """
    if niceness is not None:
        if not hasattr(os, 'setpriority'):
            DEBUG(1, 'psutil not installed, renice ignored')
        else:
            DEBUG(2, 'setting scheduler priority of {} to {}'.format(pid, niceness))
            try:
                os.setpriority(os.PRIO_PROCESS, pid, niceness)
            except ProcessLookupError:
                # the process went away in the meantime
                pass

    if not (ioclass is None and ioclassdata is None):
        DEBUG(1, 'psutil not installed, ionice ignored')

# Pick the renice implementation: psutil if available, the renice/ionice
# command-line tools on the listed platforms, and the pure-python variant
# everywhere else.
renice_pid = renice_pid_internal
if have_psutil:
    renice_pid = renice_pid_psutil
elif sys.platform in ('linux', 'linux2', 'darwin', 'gnukfreebsd8'):
    renice_pid = renice_pid_fallback

def renice_children_pid(pid, niceness, ioclass, ioclassdata):
    """Renice pid itself and every descendant it currently has."""
    targets = children_pid(pid) | set([pid])
    for target in targets:
        renice_pid(target, niceness, ioclass, ioclassdata)


def watcher(pid, run_period, sleep_period, termafter, killafter):
    """Throttle pid and its descendants until pid no longer exists.

    Each cycle lets the processes run for run_period seconds, stops them
    with SIGSTOP, waits sleep_period seconds and resumes them with SIGCONT.

    termafter -- if not None, SIGTERM is sent once this many seconds have
                 passed since the watcher started
    killafter -- if not None, SIGKILL is sent once this many seconds have
                 passed since the SIGTERM

    The function returns when signalling pid raises ProcessLookupError,
    i.e. the watched process is gone. At debug level 1 the markers
    'T', 'X', '$' and '=' are printed for TERM, KILL, STOP and CONT.
    """
    start_time = time.time()
    term_time = 0 # time when the process has been sent a TERM signal
    DEBUG(3, 'start time', start_time)
    while True:
        if termafter is not None:
            now = time.time()
            if now-start_time > termafter:
                try:
                    if not term_time:
                        DEBUG(1, 'T', end='\b')
                        DEBUG(3, 'Terminating processes after', termafter, 's')
                        kill_children(pid, signal.SIGTERM, raise_parent=True)
                        term_time = now
                    elif killafter is not None and now-term_time > killafter:
                        # term_time is set, processes have been sent the TERM signal already
                        DEBUG(1, 'X', end='\b')
                        DEBUG(3, 'Killing processes after', termafter, '+', killafter, 's')
                        kill_children(pid, signal.SIGKILL, raise_parent=True)
                except ProcessLookupError:
                    # the process finished on its own before the timeout
                    # signal could be delivered: nothing left to watch
                    break
        # collected only for the debug messages below
        children = children_pid(pid)
        time.sleep(run_period)
        DEBUG(1, '$', end='\b')
        DEBUG(2, time.asctime(), 'STOP:', pid,  children)
        try:
            kill_children(pid, signal.SIGSTOP, raise_parent=True)
            time.sleep(sleep_period)
            DEBUG(1, '=', end='\b')
            DEBUG(2, time.asctime(), 'CONT:', pid, children)
        except ProcessLookupError:
            break
        finally:
            # be careful to leave the processes running if something bad happens
            kill_children(pid, signal.SIGCONT, raise_parent=False)


def _signal_if_running(proc, sig):
    """Send sig to the Popen object proc unless it has already been reaped.

    Once the child has been waited for, its pid may be reused by an
    unrelated process, so the signal is only sent while poll() still
    reports the child as running.
    """
    if proc.poll() is not None:
        return
    try:
        os.kill(proc.pid, sig)
    except ProcessLookupError:
        pass


def run_subprocess(cmd, run_period, sleep_period, niceness, ioclass, ioclassdata, termafter, killafter):
    """Run cmd through the shell and throttle it until it finishes.

    cmd is a single command string. The new process is reniced once
    (niceness, ioclass, ioclassdata), then a daemon thread running
    watcher() stops and resumes it with the given run_period/sleep_period
    and applies the termafter/killafter timeouts.

    On CTRL C the process tree is resumed and SIGINT is forwarded to the
    main process; a second CTRL C exits with status 1. If the process is
    still running when this function is left, it is sent SIGTERM.
    """
    p = subprocess.Popen(cmd, shell=True)
    try:
        # renice only once, all the newly created children will inherit the nice values
        renice_children_pid(p.pid, niceness, ioclass, ioclassdata)

        t = threading.Thread(target=watcher, args = (p.pid, run_period, sleep_period, termafter, killafter))
        t.daemon = True
        t.start()
        try:
            p.communicate()
        except KeyboardInterrupt:
            try:
                DEBUG(2, '\nCTRL C detected')
                # unstop all children
                kill_children(p.pid, signal.SIGCONT, raise_parent=False)
                # send SIGINT to the main process
                _signal_if_running(p, signal.SIGINT)
                time.sleep(0.5) # just in case, to give the main process time to clean up
            except KeyboardInterrupt:
                # if the user pressed CTRL C quickly twice
                DEBUG(2, '\nCTRL C pressed multiple times, killing processes and exiting')
                _signal_if_running(p, signal.SIGINT)
                sys.exit(1)
    finally:
        # if the spawned process still hangs around, ask it politely to finish
        _signal_if_running(p, signal.SIGTERM)


def watch_pid(pid, run_period, sleep_period, termafter, killafter):
    """Throttle an already running pid (and its descendants) in the foreground.

    Runs watcher() in the calling thread. On CTRL C the processes are
    resumed with SIGCONT and left running; a second CTRL C during that
    clean-up exits with status 1. Whatever happens, a final SIGCONT is
    attempted so that the processes are not left stopped.
    """
    try:
        watcher(pid, run_period, sleep_period, termafter, killafter)
    except KeyboardInterrupt:
        # catch ctrl C
        # unstop processes
        try:
            DEBUG(1, '\nCTRL C detected, ending...')
            kill_children(pid, signal.SIGCONT, raise_parent=False)
        except KeyboardInterrupt:
            # if the user pressed CTRL C quickly twice
            DEBUG(2, '\nCTRL C pressed multiple times, exiting')
            sys.exit(1)
        # and finish, letting them run...
        return
    finally:
        # if anything bad happened, at least try to unstop processes;
        # kill_children looks up the descendants itself, so no separate
        # process-table scan is needed before resuming them
        try:
            kill_children(pid, signal.SIGCONT, raise_parent=False)
        except Exception as err:
            DEBUG(2, 'could not resume processes:', err)

def get_duration(x):
    """Convert a duration string to seconds.

    x is a number optionally followed by one of the suffixes 's', 'm',
    'h', 'd' or 'S'; without a suffix the number is taken as seconds.

    Returns a float, or None when x is None or an empty string, so that
    callers testing `is not None` treat a missing value as "no timeout".
    Raises ValueError if the numeric part cannot be parsed.
    """
    if not x: # None or empty string
        return None
    units = {
        's': 1,
        'm': 60,
        'h': 3600,
        # 86400.002 seconds, i.e. deliberately a touch longer than 24 h
        'd': 24*3600+0.002,
        # 24 h 39 min 35.24409 s -- presumably a Martian sol; confirm
        'S': 24*3600+39*60+35.24409,
    }
    suffix = x[-1]
    if suffix in units:
        number = x[:-1]
    else:
        suffix, number = 's', x
    return units[suffix] * float(number)


if __name__ == '__main__':
    # Command-line entry point: parse the options, then either throttle an
    # existing pid (-p) or start a command and throttle it.

    parser = ArgumentParser(prog='polynice',
            description='Aggressively throttle processes.')
    if using_optparse:
        DEBUG(3, 'using optparse')
        # give the optparse parser the argparse method names used below
        parser.add_argument = parser.add_option
        parser.parse_known_args = parser.parse_args
        parser.disable_interspersed_args()

    parser.add_argument('-v', dest='debuglevel', action='count',
                       default = 0,
                       help='Be verbose (repeat for more)')

    parser.add_argument('-c', '--class', dest='ioclass', action='store',
                       type=int,
                       default = None,
                       help='I/O scheduling class number (the same as ionice): 0: none, 1: realtime, 2: best-effort, 3: idle')

    parser.add_argument('--ion', '--classdata', dest='ioclassdata', action='store',
                       type=int,
                       default=None,
                       help='I/O scheduling class data: 0-7 for realtime and best-effort classes')

    parser.add_argument('--nn', '--niceness', dest='niceness', action='store',
                       type=int,
                       default=None,
                       help='Scheduling niceness (the same as for nice)')

    parser.add_argument('--term-after', dest='termafter', action='store',
                       type=str,
                       default=None,
                       help='Terminate the process (and its children) if it is still running after these many seconds')

    parser.add_argument('--kill-after', dest='killafter', action='store',
                       type=str,
                       default=None,
                       help='Kill the process (and its children) if it is still running after these many seconds after termination signal has been sent (with --term-after)')


    parser.add_argument('-n', dest='veryniceness', action='store',
                       type=str, default='1/4',
                       help='Throttling value (sleep/run), e.g. 1, 5, 1/3, 0.4')

    parser.add_argument('-a', dest='all_nice', action='store_true',
                       help='Set maximum possible nice values (equivalent to ionice -c 3 nice -n 20)')

    parser.add_argument('-p', dest='pid', action='store',
                       type=int,
                       help='Be nice to this pid (and all the children)')


    if not using_optparse:
        parser.add_argument('--version', action='version',
                       version='%(prog)s '+VERSION)
        parser.add_argument('command', help='Command to run', default=None, nargs='?')
        parser.add_argument('args', nargs=REMAINDER)

    args, rest = parser.parse_known_args()

    debuglevel = args.debuglevel
    DEBUG(3, 'args:', str(args))
    DEBUG(3, 'optparse:', using_optparse)
    DEBUG(3, 'psutil available:', have_psutil)

    # -n is either "sleep" (run period stays 1 second) or "sleep/run"
    sleep_period = 1.0
    run_period = 1.0
    try:
        if '/' in args.veryniceness:
            sleep_text, run_text = args.veryniceness.split('/')
            sleep_period = float(sleep_text)
            run_period = float(run_text)
        else:
            sleep_period = float(args.veryniceness)
        # negative periods cannot be slept for
        if not (sleep_period >= 0 and run_period >= 0):
            raise ValueError(args.veryniceness)
    except ValueError:
        # report a usage error instead of a traceback
        parser.error('invalid throttling value for -n: ' + repr(args.veryniceness))

    # nice(1) emulation
    niceness = args.niceness

    # ionice emulation
    ioclass = args.ioclass
    ioclassdata = args.ioclassdata

    # timeout(1) emulation
    try:
        termafter = get_duration(args.termafter)
        killafter = get_duration(args.killafter)
    except ValueError:
        parser.error('invalid duration given to --term-after or --kill-after')

    if niceness is None and args.all_nice:
        niceness = 20
    if ioclass is None and args.all_nice:
        ioclass = 3

    DEBUG(2, 'debuglevel', debuglevel)
    DEBUG(2, 'sleep period:', sleep_period)
    DEBUG(2, 'run period:', run_period)
    DEBUG(2, 'niceness:', niceness)
    DEBUG(2, 'ioclass:', ioclass)
    DEBUG(2, 'ioclassdata:', ioclassdata)
    DEBUG(2, 'pid:', args.pid)
    DEBUG(2, 'term-after:', termafter)
    DEBUG(2, 'kill-after:', killafter)

    # NOTE: the words are joined with single spaces into one command string,
    # so quoting used on the original command line is not preserved.
    if getattr(args, 'command', None):
        cmd = args.command + ' '+ ' '.join(args.args)
    else:
        # optparse, or argparse without a positional command
        cmd = ' '.join(rest)
    DEBUG(1, 'CMD:', cmd)


    if args.pid and cmd:
        print ('You cannot combine -p and command.')
        print()
        parser.print_help()
        sys.exit(1)
    if args.pid == 1:
        print ('Cowardly refusing to suspend init(8) and its children.')
        sys.exit(1)
    if args.pid:
        renice_children_pid(args.pid, niceness, ioclass, ioclassdata)
        watch_pid(args.pid, run_period, sleep_period, termafter, killafter)
    elif cmd:
        run_subprocess(cmd, run_period, sleep_period, niceness, ioclass, ioclassdata, termafter, killafter)
    else:
        print ('usage: polynice [options] command')
        parser.print_help()

