#!/usr/bin/env python3
#
#   This program is free software; you can redistribute it and/or modify
#   it under the terms of the GNU General Public License as published by
#   the Free Software Foundation; either version 2 of the License, or
#   (at your option) any later version.
#
#   This program is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU General Public License for more details.
#
#   You should have received a copy of the GNU General Public License along
#   with this program; if not, write to the Free Software Foundation, Inc.,
#   51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
import shutil
import subprocess
import re
import sys
import os

# Replace this process's PATH with a fixed list of system directories. The
# external tools launched further down (chroot, ntpq, chronyc, timedatectl,
# which) are invoked by bare name and inherit this environment.
# NOTE: "/sbin" appears twice in the list; the duplicate has no effect.
os.environ["PATH"] = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/sbin:/usr/X11R6/bin"


def check_ntpd(what):
    """Print one metric taken from ``ntpq -n -c peers`` to stdout.

    ``ntpq`` is run through ``chroot`` in the directory returned by
    ``get_root()``.  Every peer line whose first character is one of
    ``* + o x . #`` is counted as a server.

    Args:
        what: ``"count"`` prints the number of counted peer lines.
            ``"offset"``, ``"jitter"`` or ``"delay"`` prints that column,
            exactly as ntpq printed it, for the peer line flagged ``*``.
            Any other value prints nothing.

    The value is written without a trailing newline.  Nothing is written for
    offset/jitter/delay when no line is flagged ``*``.
    """
    # $ ntpq  -c peers
    # remote           refid      st t when poll reach   delay   offset  jitter
    # ==============================================================================
    # 0.ubuntu.pool.n .POOL.          16 p    -   64    0    0.000    0.000   0.000
    # 1.ubuntu.pool.n .POOL.          16 p    -   64    0    0.000    0.000   0.000
    # 2.ubuntu.pool.n .POOL.          16 p    -   64    0    0.000    0.000   0.000
    # 3.ubuntu.pool.n .POOL.          16 p    -   64    0    0.000    0.000   0.000
    # ntp.ubuntu.com  .POOL.          16 p    -   64    0    0.000    0.000   0.000
    # -ntp1.hetzner.de 205.46.178.169   2 u  911 1024  377    2.604   -0.017   0.012
    # -ntp2.hetzner.de 237.17.204.95    2 u 1046 1024  377    0.297   -0.346   0.193
    # -ntp3.hetzner.de 124.216.164.14   2 u  804 1024  377    0.315   -0.015   0.014
    # +server1b.meinbe 79.133.44.146    2 u  484 1024  377    2.641   -0.054   0.068
    # +ntp1.lwlcom.net .GPS.            1 u  720 1024  377   10.659   -0.080   0.050
    # *ntp4.lwlcom.net .GPS.            1 u  224 1024  377   10.944   -0.094   0.064
    #

    # The three trailing columns are delay, offset and jitter.  The delay
    # group must start right after whitespace: with a bare greedy ".*" in
    # front of it the engine hands all but the last digit of the delay column
    # to ".*", so "10.944" would be reported as "4".
    peer_line = re.compile(
        r"^(?P<mode>\*|\+|o|x|\.|\#)(.+?)\s+(?:.*\s)?"
        r"(?P<delay>[-+]?\d+\.?\d*)\s+"
        r"(?P<offset>[-+]?\d+\.?\d*)\s+"
        r"(?P<jitter>[-+]?\d+\.?\d*)\s*$"
    )

    # Argument list instead of a shell command string: no shell is needed and
    # this matches how which() launches its command.
    completed = subprocess.run(
        ["chroot", get_root(), "ntpq", "-n", "-c", "peers"],
        capture_output=True,
    )

    server_count = 0
    for line in completed.stdout.decode("utf8", errors="replace").splitlines():
        m = peer_line.match(line)
        if not m:
            continue

        # Debug trace, only when a human is watching stderr.
        if sys.stderr.isatty():
            sys.stderr.write("match:%s\n" % line.rstrip())
        server_count += 1

        if what in ("offset", "jitter", "delay") and m.group("mode") == "*":
            sys.stdout.write(m.group(what))

    if what == "count":
        sys.stdout.write(str(server_count))


def check_chronyc(what):
    """Print one metric obtained from ``chronyc -c`` to stdout.

    ``chronyc`` is run through ``chroot`` in the directory returned by
    ``get_root()``.

    Args:
        what: ``"count"`` prints the number of ``chronyc -c sources`` lines
            whose second CSV field is one of ``* + o x . #``.
            ``"offset"``, ``"jitter"`` or ``"delay"`` prints one field of
            ``chronyc -c tracking`` multiplied by 1000.
            Any other value prints nothing.

    The value is written without a trailing newline.  Lines that are too
    short, or whose selected field is not a number, are skipped.
    """
    # $ chronyc sources
    # MS Name/IP address         Stratum Poll Reach LastRx Last sample
    # ===============================================================================
    # ^+ prod-ntp-3.ntp1.ps5.cano>     2   6   177    32    +17us[  +17us] +/-   14ms
    # ^+ prod-ntp-4.ntp1.ps5.cano>     2   6   177    31   -187us[ -187us] +/-   11ms
    # ^+ prod-ntp-3.ntp1.ps5.cano>     2   6   177    31   +470us[ +470us] +/-   16ms
    # ^- alphyn.canonical.com          2   6   177    32  +1799us[+1799us] +/-   88ms
    # ^- 168.119.4.163.polisystem>     2   6   177    31  +2831us[+2831us] +/-   54ms
    # ^+ bblock.dev                    2   6   177    33  +1994us[+1994us] +/-   15ms
    # ^+ 2003:a:47f:abe4:48ba:cd4>     2   6   177    33   -726us[ -726us] +/- 8257us
    # ^* time.cloudflare.com           3   6   177    33    -16us[  -24us] +/- 7737us
    #

    # $ chronyc tracking
    # Reference ID    : E1FE1EBE (time.cloudflare.com)
    # Stratum         : 4
    # Ref time (UTC)  : Fri Sep 15 10:01:50 2023
    # System time     : 0.000092172 seconds fast of NTP time
    # Last offset     : +0.000214544 seconds
    # RMS offset      : 0.001256659 seconds
    # Frequency       : 7.671 ppm slow
    # Residual freq   : +0.119 ppm
    # Skew            : 3.602 ppm
    # Root delay      : 0.014633909 seconds
    # Root dispersion : 0.000646871 seconds
    # Update interval : 65.2 seconds
    # Leap status     : Normal
    #

    if what == "count":
        completed = subprocess.run(
            ["chroot", get_root(), "chronyc", "-c", "sources"],
            capture_output=True,
        )
        server_count = 0
        for line in completed.stdout.decode("utf8", errors="replace").splitlines():
            values = line.split(",")
            # Guard the index: blank or non-CSV lines have no second field.
            if len(values) < 2 or values[1] not in ("*", "+", "o", "x", ".", "#"):
                continue
            # Debug trace, only when a human is watching stderr.
            if sys.stderr.isatty():
                sys.stderr.write("match:%s\n" % line.rstrip())
            server_count += 1
        sys.stdout.write(str(server_count))
        return

    # CSV columns of "chronyc -c tracking" follow the human-readable listing
    # above, with the reference ID and its name as two separate fields:
    #   0 ref id, 1 name, 2 stratum, 3 ref time, 4 system time, 5 last offset,
    #   6 RMS offset, 7 frequency, 8 residual freq, 9 skew, 10 root delay,
    #   11 root dispersion, 12 update interval, 13 leap status
    # "delay" used to read column 9 (the same column as "jitter"); root delay
    # is column 10.
    # NOTE(review): "jitter" still reads column 9, which is the skew in ppm,
    # not a time value - confirm whether RMS offset (column 6) was intended.
    column = {"offset": 4, "jitter": 9, "delay": 10}.get(what)
    if column is None:
        return

    completed = subprocess.run(
        ["chroot", get_root(), "chronyc", "-c", "tracking"],
        capture_output=True,
    )
    for line in completed.stdout.decode("utf8", errors="replace").splitlines():
        values = line.split(",")
        try:
            value = float(values[column])
        except (IndexError, ValueError):
            # Not a tracking record (e.g. an error message); ignore it.
            continue
        sys.stdout.write(str(value * 1000))


def _systemd_component_to_ms(part):
    """Convert one ``<number><unit>`` component of a time span to milliseconds.

    Raises ``ValueError`` when the number cannot be parsed or the unit is
    not recognised.
    """
    # Order matters: "us"/"ms" must be tested before "s", and "month" before
    # "h", because the shorter suffix is a tail of the longer one.
    if part.endswith("us"):
        return float(part[:-2]) / 1000
    if part.endswith("ms"):
        return float(part[:-2])
    if part.endswith("min"):
        return float(part[:-3]) * 1000 * 60
    if part.endswith("month"):
        # systemd defines a month as 2629800 seconds (30.4375 days).
        return float(part[:-5]) * 1000 * 2629800
    if part.endswith("s"):
        return float(part[:-1]) * 1000
    if part.endswith("h"):
        return float(part[:-1]) * 1000 * 60 * 60
    if part.endswith("d"):
        return float(part[:-1]) * 1000 * 60 * 60 * 24
    if part.endswith("w"):
        return float(part[:-1]) * 1000 * 60 * 60 * 24 * 7
    if part.endswith("y"):
        # systemd defines a year as 31557600 seconds (365.25 days).
        return float(part[:-1]) * 1000 * 31557600
    # A zero span is printed as a bare "0" without any unit.
    if float(part) == 0:
        return 0.0
    raise ValueError(part)


def convert_time_systemd(v):
    """Convert a systemd time span such as ``"34min 8s"`` to milliseconds.

    Args:
        v: Whitespace-separated components, each a number followed by one of
            the units ``us``, ``ms``, ``s``, ``min``, ``h``, ``d``, ``w``,
            ``month`` or ``y``.  A bare ``"0"`` is accepted as zero.

    Returns:
        The total number of milliseconds as ``str(float)``, e.g.
        ``"2048000.0"``.

    Raises:
        ValueError: if ``v`` is empty or contains a component that cannot be
            parsed.
    """
    parts = v.split()
    if not parts:
        raise ValueError("Unknown time format: %s" % v)

    time_ms = 0.0
    for part in parts:
        try:
            time_ms += _systemd_component_to_ms(part)
        except ValueError:
            raise ValueError("Unknown time format: %s" % part) from None

    return str(time_ms)


def check_systemd(what):
    """Print one metric obtained from ``timedatectl show-timesync`` to stdout.

    ``timedatectl`` is run through ``chroot`` in the directory returned by
    ``get_root()``, and the ``NTPMessage={ ... }`` line of its output is
    parsed.

    Args:
        what: ``"count"`` prints the fixed placeholder ``99`` without running
            any command.  ``"jitter"`` prints the ``Jitter`` field,
            ``"delay"`` the ``RootDelay`` field and ``"offset"`` the
            ``RootDispersion`` field, each converted with
            ``convert_time_systemd``.  Any other value prints nothing.

    The value is written without a trailing newline.
    """
    # $ timedatectl show-timesync
    # FallbackNTPServers=ntp.ubuntu.com
    # ServerName=ntp.ubuntu.com
    # ServerAddress=2620:2d:4000:1::3f
    # RootDistanceMaxUSec=5s
    # PollIntervalMinUSec=32s
    # PollIntervalMaxUSec=34min 8s
    # PollIntervalUSec=17min 4s
    # NTPMessage={ Leap=0, Version=4, Mode=4, Stratum=2, Precision=-25, RootDelay=8.926ms, RootDispersion=1.770ms, Reference=91EECB0E, OriginateTimestamp=Fri 2023-09-15 10:05:40 CEST, ReceiveTimestamp=Fri 2023-09-15 10:05:40 CEST, TransmitTimestamp=Fri 2023-09-15 10:05:40 CEST, DestinationTimestamp=Fri 2023-09-15 10:05:40 CEST, Ignored=no PacketCount=25, Jitter=279.841ms }
    # Frequency=446926

    # Systemd does not support multiple servers, therefore we only check the
    # first one and return an unrealistic value.
    server_count = 99
    if what == "count":
        sys.stdout.write(str(server_count))
        return

    # NOTE(review): "offset" is answered with RootDispersion, which is not a
    # clock offset - confirm this substitution is intended.
    wanted_key = {
        "jitter": "Jitter",
        "delay": "RootDelay",
        "offset": "RootDispersion",
    }.get(what)

    # Argument list instead of a shell command string: no shell is needed and
    # this matches how which() launches its command.
    completed = subprocess.run(
        ["chroot", get_root(), "timedatectl", "show-timesync"],
        capture_output=True,
    )
    for line in completed.stdout.decode("utf8", errors="replace").splitlines():
        m = re.match(r"^NTPMessage=\{(..+)\}", line)
        if not m:
            continue

        # Debug trace, only when a human is watching stderr.
        if sys.stderr.isatty():
            sys.stderr.write("match:%s\n" % line.rstrip())

        for item in re.split(r"\s*,\s*", m.group(1)):
            m_item = re.match(r"^(.+?)=(.+)$", item.strip())
            if not m_item:
                # Not a key=value pair; nothing to extract.
                continue
            if m_item.group(1) == wanted_key:
                sys.stdout.write(convert_time_systemd(m_item.group(2)))

def get_root():
    """Return the directory that the external tools are chroot-ed into.

    ``/rootfs`` is returned when that directory exists, ``/`` otherwise.
    """
    rootfs = "/rootfs"
    return rootfs if os.path.isdir(rootfs) else "/"


def which(tool: str):
    """Report whether ``which`` finds *tool* inside the get_root() chroot.

    Returns True when the ``chroot <root> which <tool>`` command exits with
    status 0 and False for any other exit status.  The command's output is
    captured and discarded.
    """
    command = ["chroot", get_root(), "which", tool]
    completed = subprocess.run(command, capture_output=True, text=True)
    return completed.returncode == 0


# Command-line entry point: exactly one argument selects the metric to print.
if len(sys.argv) != 2:
    print(f"{sys.argv[0]} offset|jitter|delay|count")
    sys.exit(1)

what = sys.argv[1]

# Probe the supported clients in priority order and hand over to the first
# one whose command-line tool is found; later tools are not probed.
for tool, handler in (
    ("ntpq", check_ntpd),
    ("chronyc", check_chronyc),
    ("timedatectl", check_systemd),
):
    if which(tool):
        handler(what)
        break
else:
    print("No ntp client implementation found", file=sys.stderr)
    sys.exit(1)

# vim: ai et ts=2 shiftwidth=2 expandtab
