#!/usr/bin/env python3
# Govibe Disk Power Management limited administrator helper.
# This helper is intentionally narrow: it only accepts disk actions used by the GUI.

import json
import os
import re
import subprocess
import sys

# Kernel block-device names this helper is allowed to touch: sd*, hd*, vd*, xvd*
# (each with optional partition digits), NVMe namespaces and eMMC/SD cards (with an
# optional pN partition suffix), and sr* optical drives. Anything else is rejected.
SAFE_BLOCK_NAME = re.compile(r'^(sd[a-z][0-9]*|hd[a-z][0-9]*|vd[a-z][0-9]*|xvd[a-z][0-9]*|nvme[0-9]+n[0-9]+(p[0-9]+)?|mmcblk[0-9]+(p[0-9]+)?|sr[0-9]+)$')
# Mountpoints that mark the whole disk as part of the running system.
# "[SWAP]" is the value lsblk reports for an active swap device: sleeping or powering
# off the disk that backs live swap is as fatal as pulling the root disk.
PROTECTED_EXACT = {'/', '/boot', '/boot/efi', '/home', '/usr', '/var', '/opt', '/srv', '[SWAP]'}
# Any mountpoint below one of these roots is treated as a system mount as well.
PROTECTED_PREFIXES = ('/snap/', '/run/', '/sys/', '/proc/', '/dev/')


def fail(message, code=2):
    """Report *message* on stderr and terminate the helper with exit status *code*.

    This function never returns; raising SystemExit is exactly what sys.exit() does.
    """
    sys.stderr.write(f'{message}\n')
    raise SystemExit(code)


def run(args, timeout=120):
    """Execute *args* without a shell and return ``(returncode, stdout, stderr)``.

    Both output streams are decoded as text and stripped. A timeout becomes exit
    code 124 with the message 'Command timed out'. Any other launch error (for
    example a missing executable) becomes exit code 1 with the exception text as
    stderr, so callers never have to handle exceptions themselves.
    """
    try:
        completed = subprocess.run(args, text=True, capture_output=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return 124, '', 'Command timed out'
    except Exception as exc:
        return 1, '', str(exc)
    return completed.returncode, completed.stdout.strip(), completed.stderr.strip()


def safe_device(path):
    """Validate a user-supplied device path and return its resolved /dev path.

    The checks run in order, and fail() exits on the first one that does not hold:
      1. the path is a string beginning with '/dev/';
      2. it contains only letters, digits and the characters '/_-.:+';
      3. after resolving symlinks and '..' it still lives under '/dev/';
      4. the resolved basename matches SAFE_BLOCK_NAME;
      5. the resolved path exists.
    """
    allowed = frozenset('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789/_-.:+')
    if not (isinstance(path, str) and path.startswith('/dev/')):
        fail('Blocked unsafe device path.')
    if not allowed.issuperset(path):
        fail('Blocked unsafe characters in device path.')
    resolved = os.path.realpath(path)
    # Symlinks such as /dev/disk/by-id/* must still land inside /dev.
    if not resolved.startswith('/dev/'):
        fail('Blocked device path outside /dev.')
    name = os.path.basename(resolved)
    if SAFE_BLOCK_NAME.match(name) is None:
        fail('Blocked unsupported block device name: ' + name)
    if not os.path.exists(resolved):
        fail('Device does not exist: ' + resolved)
    return resolved


def flatten_mountpoints(mp):
    """Normalise an lsblk mountpoint value into a list of non-empty strings.

    A list (the MOUNTPOINTS JSON form, which may contain nulls) keeps its truthy
    entries, converted with str(). A newline-separated string is split into lines,
    and empty lines are dropped. Any other value, including None, gives [].
    """
    if isinstance(mp, list):
        return [str(entry) for entry in mp if entry]
    if not isinstance(mp, str):
        return []
    return [line for line in mp.split('\n') if line]


def is_system_mount(mountpoint):
    """Return True when *mountpoint* marks its disk as part of the running system.

    An empty or missing mountpoint is never a system mount. An exact match in
    PROTECTED_EXACT always is. Otherwise the mountpoint is a system mount if it lies
    under one of PROTECTED_PREFIXES, with one exception: anything below
    '/run/media/'. That is the location udisks2 uses for removable media on
    distributions that do not relocate it to /media, which includes mounts made by
    this helper's own root-run ``udisksctl mount``. Without the exemption, '/run/'
    would flag every such drive as a protected system disk. Other '/run/' mounts,
    such as live-boot media, remain protected.
    """
    if not mountpoint:
        return False
    if mountpoint in PROTECTED_EXACT:
        return True
    if mountpoint.startswith('/run/media/'):
        return False
    return mountpoint.startswith(PROTECTED_PREFIXES)


def collect_mountpoints(node):
    """Gather the mountpoints of an lsblk *node* and all of its descendants.

    Nodes are visited depth-first in pre-order, so a parent's mountpoints come
    before those of its children. Non-dict entries in 'children' are ignored.
    """
    collected = []
    pending = [node]
    while pending:
        current = pending.pop()
        collected.extend(flatten_mountpoints(current.get('mountpoints')))
        kids = [child for child in (current.get('children') or []) if isinstance(child, dict)]
        # Push in reverse so the first child is popped (and fully walked) first.
        pending.extend(reversed(kids))
    return collected


def disk_is_system(node):
    """Return True if any mountpoint in the *node* subtree is a system mount."""
    for point in collect_mountpoints(node):
        if is_system_mount(point):
            return True
    return False


def lsblk_tree():
    """Return the top-level ``blockdevices`` list from ``lsblk -J``.

    Exits via fail() when lsblk fails, when its output is not valid JSON, or when
    the JSON does not contain a ``blockdevices`` list. Failing closed in that last
    case matters: an empty tree would make every device look unrelated to any
    protected disk and silently disable the system-disk check.
    """
    code, out, err = run(['lsblk', '-J', '-o', 'NAME,PATH,TYPE,MOUNTPOINTS'], timeout=20)
    if code != 0:
        fail(err or 'Could not read lsblk.')
    try:
        data = json.loads(out)
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        fail('Could not parse lsblk JSON: ' + str(exc))
    devices = data.get('blockdevices') if isinstance(data, dict) else None
    if not isinstance(devices, list):
        fail('Could not parse lsblk JSON: no blockdevices list.')
    return devices


def node_contains_path(node, path):
    """Return True if *node* or any of its descendants resolves to *path*.

    *path* is expected to be already resolved (see safe_device). Each node's
    'path' is passed through os.path.realpath before comparison. Nodes without a
    'path' are skipped, because realpath('') would otherwise resolve to the current
    working directory. Non-dict entries in 'children' are ignored.
    """
    node_path = node.get('path')
    if node_path and os.path.realpath(str(node_path)) == path:
        return True
    return any(
        isinstance(child, dict) and node_contains_path(child, path)
        for child in node.get('children') or []
    )


def parent_disk_node(path):
    """Return the top-level lsblk node that contains *path*, or None.

    Both whole disks ('disk') and optical drives ('rom') are searched.
    SAFE_BLOCK_NAME accepts sr* devices, which lsblk types as 'rom'. If only
    'disk' nodes were searched, the protected-system-disk check would be skipped
    entirely for those drives, for example for live-boot media.
    """
    for node in lsblk_tree():
        if not isinstance(node, dict) or node.get('type') not in ('disk', 'rom'):
            continue
        if node_contains_path(node, path):
            return node
    return None


def reject_if_system_related(path):
    """Exit via fail() if *path* lives on a disk that carries a system mount.

    When no containing top-level node is found, nothing is rejected.
    """
    owner = parent_disk_node(path)
    if not owner:
        return
    if disk_is_system(owner):
        fail('Blocked: selected device belongs to a protected system disk.')


def command_output(code, out, err):
    """Echo non-empty *out* to stdout and *err* to stderr, then return *code*."""
    for text, stream in ((out, sys.stdout), (err, sys.stderr)):
        if text:
            print(text, file=stream)
    return code


def action_unmount(devices):
    """Unmount every partition in *devices*; return 0 or the last failing exit code.

    Every argument is validated with safe_device() and checked against protected
    system disks *before* anything is unmounted. A bad path anywhere in the list
    therefore rejects the whole request, instead of exiting part-way after earlier
    partitions were already unmounted and discarding the output collected so far.

    For each partition, ``udisksctl unmount`` is tried first. If it fails, plain
    ``umount`` is tried, and the output of both attempts is reported.
    """
    if not devices:
        fail('No partition was provided for unmount.')
    # Validate the complete request up front; fail() exits on the first problem.
    targets = []
    for item in devices:
        dev = safe_device(item)
        reject_if_system_related(dev)
        targets.append(dev)
    messages = []
    errors = []
    code_final = 0
    run(['sync'], timeout=30)
    for dev in targets:
        code, out, err = run(['udisksctl', 'unmount', '--no-user-interaction', '-b', dev], timeout=90)
        if code != 0:
            # Fallback: the exit code of umount decides this partition's result.
            code, out2, err2 = run(['umount', dev], timeout=90)
            out = '\n'.join(x for x in [out, out2] if x)
            err = '\n'.join(x for x in [err, err2] if x)
        if out:
            messages.append(out)
        if err:
            errors.append(err)
        if code != 0:
            code_final = code
    if messages:
        print('\n'.join(messages))
    if errors:
        print('\n'.join(errors), file=sys.stderr)
    return code_final


def action_sleep(device):
    """Put the disk at *device* into standby with ``hdparm -y``.

    The path is validated and checked against protected system disks first. A
    best-effort sync runs before the command, and its result is ignored.
    Returns hdparm's exit code after echoing its output.
    """
    target = safe_device(device)
    reject_if_system_related(target)
    run(['sync'], timeout=30)
    status, stdout_text, stderr_text = run(['hdparm', '-y', target], timeout=60)
    return command_output(status, stdout_text, stderr_text)


def action_poweroff(device):
    """Power off the drive at *device* through ``udisksctl power-off``.

    The path is validated and checked against protected system disks first. A
    best-effort sync runs before the command, and its result is ignored.
    Returns udisksctl's exit code after echoing its output.
    """
    target = safe_device(device)
    reject_if_system_related(target)
    run(['sync'], timeout=30)
    return command_output(*run(['udisksctl', 'power-off', '--no-user-interaction', '-b', target], timeout=120))


def action_rescan():
    """Ask every SCSI host to rescan its bus, then wait for udev to settle.

    '- - -' is written to each /sys/class/scsi_host/*/scan file that exists.
    Write failures are collected and reported on stderr but do not change the
    exit status. The return value is the exit code of ``udevadm settle``.
    """
    scsi_root = '/sys/class/scsi_host'
    problems = []
    hosts = sorted(os.listdir(scsi_root)) if os.path.isdir(scsi_root) else []
    for host in hosts:
        scan_file = os.path.join(scsi_root, host, 'scan')
        if not os.path.exists(scan_file):
            continue
        try:
            with open(scan_file, 'w', encoding='utf-8') as handle:
                handle.write('- - -\n')
        except Exception as exc:
            problems.append(f'{scan_file}: {exc}')
    status, out, err = run(['udevadm', 'settle'], timeout=60)
    if out:
        print(out)
    if err:
        problems.append(err)
    if problems:
        print('\n'.join(problems), file=sys.stderr)
    return status


def action_mount(device):
    """Mount the partition at *device* through ``udisksctl mount``.

    The path is validated and checked against protected system disks first.
    Returns udisksctl's exit code after echoing its output.
    """
    target = safe_device(device)
    reject_if_system_related(target)
    result = run(['udisksctl', 'mount', '--no-user-interaction', '-b', target], timeout=120)
    return command_output(*result)


def main():
    """Dispatch ``sys.argv[1]`` to the matching action and return its exit code.

    Accepted actions: ping, unmount <partition>..., rescan, and the single-argument
    actions sleep <disk>, poweroff <disk> and mount <partition>. Anything else,
    including a missing action or a wrong argument count, exits via fail().
    """
    argv = sys.argv
    if len(argv) < 2:
        fail('Missing action.')
    action, params = argv[1], argv[2:]
    if action == 'ping':
        print('Administrator helper is ready.')
        return 0
    if action == 'unmount':
        return action_unmount(params)
    if action == 'rescan':
        return action_rescan()
    # Actions that take exactly one device path, with their usage message.
    single_arg_actions = {
        'sleep': (action_sleep, 'Sleep requires one disk path.'),
        'poweroff': (action_poweroff, 'Power off requires one disk path.'),
        'mount': (action_mount, 'Mount requires one partition path.'),
    }
    if action in single_arg_actions:
        handler, usage = single_arg_actions[action]
        if len(params) != 1:
            fail(usage)
        return handler(params[0])
    fail('Unknown action: ' + action)


if __name__ == '__main__':
    # The action's return value becomes the process exit status (same as sys.exit).
    raise SystemExit(main())
