Files
TerraHome/nixos/kubeadm/scripts/discover-inventory-from-ssh.py

183 lines
6.7 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import concurrent.futures
import ipaddress
import json
import os
import subprocess
import sys
from typing import Dict, Set, Tuple
def derive_prefix(payload: dict) -> str:
    """Pick the first-three-octet prefix (e.g. ``"10.27.27"``) of the range to scan.

    Resolution order:

    1. ``KUBEADM_SUBNET_PREFIX`` from the environment, whitespace-stripped,
       returned verbatim when non-empty.
    2. The first truthy address in the ``control_plane_vm_ipv4`` output, then
       in ``worker_vm_ipv4``, that splits into exactly four dot-separated
       fields; its first three fields are joined and returned.
    3. The hard-coded fallback ``"10.27.27"``.

    Args:
        payload: Parsed JSON; ``payload[key]["value"]`` is assumed to be a
            mapping of node name to address string (only ``.values()`` is used).
    """
    override = os.environ.get("KUBEADM_SUBNET_PREFIX", "").strip()
    if override:
        return override

    for output_name in ("control_plane_vm_ipv4", "worker_vm_ipv4"):
        addresses = payload.get(output_name, {}).get("value", {})
        for address in addresses.values():
            if not address:
                continue
            octets = address.split(".")
            if len(octets) != 4:
                continue
            # Drop the host octet; keep "a.b.c".
            return ".".join(octets[:3])

    return "10.27.27"
def ssh_probe(ip: str, users: list[str], key_path: str, timeout_sec: int) -> Tuple[str, str, str] | None:
cmd_tail = [
"-o",
"BatchMode=yes",
"-o",
"IdentitiesOnly=yes",
"-o",
"StrictHostKeyChecking=accept-new",
"-o",
f"ConnectTimeout={timeout_sec}",
"-i",
key_path,
]
for user in users:
cmd = [
"ssh",
*cmd_tail,
f"{user}@{ip}",
"hn=$(hostnamectl --static 2>/dev/null || hostname); serial=$(cat /sys/class/dmi/id/product_serial 2>/dev/null || true); printf '%s|%s\n' \"$hn\" \"$serial\"",
]
try:
out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, text=True, timeout=timeout_sec + 2).strip()
except Exception:
continue
if out:
line = out.splitlines()[0].strip()
if "|" in line:
host, serial = line.split("|", 1)
else:
host, serial = line, ""
return host.strip(), ip, serial.strip()
return None
def build_inventory(names: Set[str], found: Dict[str, str], ssh_user: str) -> str:
    """Render discovered node addresses as ``KEY=value`` inventory lines.

    Names starting with ``cp-`` are emitted as control planes and names
    starting with ``wk-`` as workers; any other name is not emitted. Within
    each group, nodes are ordered by the numeric field after the first ``-``
    (``cp-2`` before ``cp-10``). Non-numeric fields sort after numeric ones by
    name instead of raising ``ValueError``.

    Args:
        names: Node names to consider.
        found: Mapping of node name to IP; must contain every emitted name.
        ssh_user: Value written as ``SSH_USER``.

    Returns:
        ``SSH_USER``, ``PRIMARY_CONTROL_PLANE`` (first control plane, or
        ``cp-1`` when there is none), ``CONTROL_PLANES`` and ``WORKERS``
        lines, newline-terminated.

    Raises:
        KeyError: if an emitted name has no entry in *found*.
    """

    def order_key(name: str) -> tuple[int, int, str]:
        # The name is included last so ties (e.g. "cp-1" vs "cp-01") are
        # deterministic instead of depending on set iteration order.
        field = name.split("-")[1]
        try:
            return (0, int(field), name)
        except ValueError:
            return (1, 0, name)

    cp = sorted((n for n in names if n.startswith("cp-")), key=order_key)
    wk = sorted((n for n in names if n.startswith("wk-")), key=order_key)
    cp_pairs = " ".join(f"{n}={found[n]}" for n in cp)
    wk_pairs = " ".join(f"{n}={found[n]}" for n in wk)
    primary = cp[0] if cp else "cp-1"
    return "\n".join(
        [
            f"SSH_USER={ssh_user}",
            f"PRIMARY_CONTROL_PLANE={primary}",
            f'CONTROL_PLANES="{cp_pairs}"',
            f'WORKERS="{wk_pairs}"',
            "",
        ]
    )
def main() -> int:
    """Discover node IPs by SSH-scanning a subnet and print a kubeadm inventory.

    Reads JSON (Terraform outputs) from stdin. ``control_plane_vm_ids`` and
    ``worker_vm_ids`` (``["value"]``: name -> VM id) give the target node
    names; the payload is also passed to ``derive_prefix``.

    Every address ``<prefix>.<start>`` .. ``<prefix>.<end>`` except the
    control-plane VIP is probed with ``ssh_probe``. A reachable host is mapped
    to a node when its DMI serial equals that node's VM id, or otherwise when
    its hostname equals the node name. If nodes are still missing, a slower
    second pass is run, and as a last resort the still-missing nodes are paired
    with unassigned reachable IPs in numeric IP order.

    Environment: ``KUBEADM_SSH_USER``, ``SSH_USER_CANDIDATES``,
    ``SSH_KEY_PATH``, ``SSH_DISCOVERY_TIMEOUT_SEC``, ``SSH_DISCOVERY_WORKERS``,
    ``KUBEADM_SUBNET_START``, ``KUBEADM_SUBNET_END``,
    ``KUBEADM_CONTROL_PLANE_VIP_SUFFIX``.

    Returns:
        0 after writing the inventory to stdout.

    Raises:
        SystemExit: if no node names are present, or if some nodes could not
            be located (diagnostics are written to stderr first).
    """

    def node_order(name: str) -> tuple[str, int, int, str]:
        # Keep names grouped by the text before the first "-" and order numeric
        # suffixes numerically ("cp-2" before "cp-10"); other names follow
        # within their group by name.
        head, sep, tail = name.partition("-")
        if sep and tail.isdecimal():
            return (head, 0, int(tail), name)
        return (head, 1, 0, name)

    payload = json.load(sys.stdin)
    cp_ids = payload.get("control_plane_vm_ids", {}).get("value", {})
    wk_ids = payload.get("worker_vm_ids", {}).get("value", {})
    target_names = set(cp_ids.keys()) | set(wk_ids.keys())
    if not target_names:
        raise SystemExit("Could not determine target node names from Terraform outputs")

    ssh_user = os.environ.get("KUBEADM_SSH_USER", "").strip() or "micqdf"
    users = [u for u in os.environ.get("SSH_USER_CANDIDATES", f"{ssh_user} root").split() if u]
    key_path = os.environ.get("SSH_KEY_PATH", os.path.expanduser("~/.ssh/id_ed25519"))
    timeout_sec = int(os.environ.get("SSH_DISCOVERY_TIMEOUT_SEC", "6"))
    # ThreadPoolExecutor raises ValueError for max_workers <= 0.
    max_workers = max(1, int(os.environ.get("SSH_DISCOVERY_WORKERS", "32")))
    prefix = derive_prefix(payload)
    start = int(os.environ.get("KUBEADM_SUBNET_START", "2"))
    end = int(os.environ.get("KUBEADM_SUBNET_END", "254"))
    vip_suffix = int(os.environ.get("KUBEADM_CONTROL_PLANE_VIP_SUFFIX", "250"))

    def is_vip_ip(ip: str) -> bool:
        # True when the last octet equals the configured VIP suffix.
        try:
            return int(ip.split(".")[-1]) == vip_suffix
        except ValueError:
            return False

    # The VIP address is never probed (presumably because it answers as
    # whichever control plane currently holds it).
    scan_ips = [
        str(ipaddress.IPv4Address(f"{prefix}.{i}"))
        for i in range(start, end + 1)
        if i != vip_suffix
    ]

    found: Dict[str, str] = {}
    # VM ids (as strings) -> node name; matched against the DMI serial.
    # Workers are inserted after control planes, so they win on a duplicate id.
    vmid_to_name: Dict[str, str] = {
        str(vmid): name for ids in (cp_ids, wk_ids) for name, vmid in ids.items()
    }
    seen_hostnames: Dict[str, str] = {}
    seen_ips: Dict[str, Tuple[str, str]] = {}

    def all_found() -> bool:
        return all(name in found for name in target_names)

    def run_pass(pass_timeout: int, pass_workers: int) -> None:
        # Probe every scan IP concurrently; record reachable hosts and map
        # recognised ones into `found`. Stops early once all nodes are found.
        with concurrent.futures.ThreadPoolExecutor(max_workers=pass_workers) as pool:
            futures = [pool.submit(ssh_probe, ip, users, key_path, pass_timeout) for ip in scan_ips]
            for fut in concurrent.futures.as_completed(futures):
                result = fut.result()
                if not result:
                    continue
                host, ip, serial = result
                seen_hostnames.setdefault(host, ip)
                seen_ips.setdefault(ip, (host, serial))
                target = None
                if serial in vmid_to_name:
                    target = vmid_to_name[serial]
                elif host in target_names:
                    target = host
                if target:
                    existing = found.get(target)
                    if existing is None or (is_vip_ip(existing) and not is_vip_ip(ip)):
                        found[target] = ip
                if all_found():
                    # Leaving the `with` block waits for every outstanding
                    # future; cancel the queued ones so the early exit really
                    # skips the rest of the scan.
                    for pending in futures:
                        pending.cancel()
                    return

    run_pass(timeout_sec, max_workers)
    if not all_found():
        # Slower second pass for busy runners/networks.
        run_pass(max(timeout_sec + 2, 8), max(8, max_workers // 2))

    # Heuristic fallback: if nodes are still missing, pair them (in node order)
    # with the remaining SSH-reachable IPs not already used, in numeric IP
    # order. This helps when cloned nodes temporarily share a generic hostname
    # (e.g. "flex") and DMI serial mapping is unavailable.
    missing = sorted((n for n in target_names if n not in found), key=node_order)
    if missing:
        used_ips = set(found.values())
        # Numeric order: plain string order would put "x.x.x.100" before "x.x.x.21".
        candidates = sorted((ip for ip in seen_ips if ip not in used_ips), key=ipaddress.IPv4Address)
        if len(candidates) >= len(missing):
            found.update(zip(missing, candidates))

    missing = sorted((n for n in target_names if n not in found), key=node_order)
    if missing:
        discovered = ", ".join(sorted(seen_hostnames.keys())[:20])
        if discovered:
            sys.stderr.write(f"Discovered hostnames during scan: {discovered}\n")
        if seen_ips:
            by_address = sorted(seen_ips.items(), key=lambda item: ipaddress.IPv4Address(item[0]))
            sample = ", ".join(f"{ip}={meta[0]}" for ip, meta in by_address[:20])
            sys.stderr.write(f"SSH-reachable IPs: {sample}\n")
        raise SystemExit(
            "Failed SSH-based IP discovery for nodes: " + ", ".join(missing) +
            f" (scanned {prefix}.{start}-{prefix}.{end})"
        )

    sys.stdout.write(build_inventory(target_names, found, ssh_user))
    return 0
if __name__ == "__main__":
    # sys.exit(code) raises SystemExit(code), so main()'s return value becomes
    # the process exit status.
    sys.exit(main())