#!/usr/bin/env python3 import concurrent.futures import ipaddress import json import os import subprocess import sys from typing import Dict, Set, Tuple def derive_prefix(payload: dict) -> str: explicit = os.environ.get("KUBEADM_SUBNET_PREFIX", "").strip() if explicit: return explicit for key in ("control_plane_vm_ipv4", "worker_vm_ipv4"): values = payload.get(key, {}).get("value", {}) for ip in values.values(): if ip: parts = ip.split(".") if len(parts) == 4: return ".".join(parts[:3]) return "10.27.27" def ssh_probe(ip: str, users: list[str], key_path: str, timeout_sec: int) -> Tuple[str, str, str] | None: cmd_tail = [ "-o", "BatchMode=yes", "-o", "IdentitiesOnly=yes", "-o", "StrictHostKeyChecking=accept-new", "-o", f"ConnectTimeout={timeout_sec}", "-i", key_path, ] for user in users: cmd = [ "ssh", *cmd_tail, f"{user}@{ip}", "hn=$(hostnamectl --static 2>/dev/null || hostname); serial=$(cat /sys/class/dmi/id/product_serial 2>/dev/null || true); printf '%s|%s\n' \"$hn\" \"$serial\"", ] try: out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, text=True, timeout=timeout_sec + 2).strip() except Exception: continue if out: line = out.splitlines()[0].strip() if "|" in line: host, serial = line.split("|", 1) else: host, serial = line, "" return host.strip(), ip, serial.strip() return None def build_inventory(names: Set[str], found: Dict[str, str], ssh_user: str) -> str: cp = sorted([n for n in names if n.startswith("cp-")], key=lambda x: int(x.split("-")[1])) wk = sorted([n for n in names if n.startswith("wk-")], key=lambda x: int(x.split("-")[1])) cp_pairs = " ".join(f"{n}={found[n]}" for n in cp) wk_pairs = " ".join(f"{n}={found[n]}" for n in wk) primary = cp[0] if cp else "cp-1" return "\n".join( [ f"SSH_USER={ssh_user}", f"PRIMARY_CONTROL_PLANE={primary}", f'CONTROL_PLANES="{cp_pairs}"', f'WORKERS="{wk_pairs}"', "", ] ) def main() -> int: payload = json.load(sys.stdin) cp_names = set(payload.get("control_plane_vm_ids", {}).get("value", {}).keys()) wk_names = set(payload.get("worker_vm_ids", {}).get("value", {}).keys()) target_names = cp_names | wk_names if not target_names: raise SystemExit("Could not determine target node names from Terraform outputs") ssh_user = os.environ.get("KUBEADM_SSH_USER", "").strip() or "micqdf" users = [u for u in os.environ.get("SSH_USER_CANDIDATES", f"{ssh_user} root").split() if u] key_path = os.environ.get("SSH_KEY_PATH", os.path.expanduser("~/.ssh/id_ed25519")) timeout_sec = int(os.environ.get("SSH_DISCOVERY_TIMEOUT_SEC", "6")) max_workers = int(os.environ.get("SSH_DISCOVERY_WORKERS", "32")) prefix = derive_prefix(payload) start = int(os.environ.get("KUBEADM_SUBNET_START", "2")) end = int(os.environ.get("KUBEADM_SUBNET_END", "254")) vip_suffix = int(os.environ.get("KUBEADM_CONTROL_PLANE_VIP_SUFFIX", "250")) def is_vip_ip(ip: str) -> bool: try: return int(ip.split(".")[-1]) == vip_suffix except Exception: return False scan_ips = [ str(ipaddress.IPv4Address(f"{prefix}.{i}")) for i in range(start, end + 1) if i != vip_suffix ] found: Dict[str, str] = {} vmid_to_name: Dict[str, str] = {} for name, vmid in payload.get("control_plane_vm_ids", {}).get("value", {}).items(): vmid_to_name[str(vmid)] = name for name, vmid in payload.get("worker_vm_ids", {}).get("value", {}).items(): vmid_to_name[str(vmid)] = name seen_hostnames: Dict[str, str] = {} seen_ips: Dict[str, Tuple[str, str]] = {} def run_pass(pass_timeout: int, pass_workers: int) -> None: with concurrent.futures.ThreadPoolExecutor(max_workers=pass_workers) as pool: futures = [pool.submit(ssh_probe, ip, users, key_path, pass_timeout) for ip in scan_ips] for fut in concurrent.futures.as_completed(futures): result = fut.result() if not result: continue host, ip, serial = result if host not in seen_hostnames: seen_hostnames[host] = ip if ip not in seen_ips: seen_ips[ip] = (host, serial) target = None if serial in vmid_to_name: inferred = vmid_to_name[serial] target = inferred elif host in target_names: target = host if target: existing = found.get(target) if existing is None or (is_vip_ip(existing) and not is_vip_ip(ip)): found[target] = ip if all(name in found for name in target_names): return run_pass(timeout_sec, max_workers) if not all(name in found for name in target_names): # Slower second pass for busy runners/networks. run_pass(max(timeout_sec + 2, 8), max(8, max_workers // 2)) # Heuristic fallback: if nodes still missing, assign from remaining SSH-reachable # IPs not already used, ordered by IP. This helps when cloned nodes temporarily # share a generic hostname (e.g. "flex") and DMI serial mapping is unavailable. missing = sorted([n for n in target_names if n not in found]) if missing: used_ips = set(found.values()) candidates = sorted(ip for ip in seen_ips.keys() if ip not in used_ips) if len(candidates) >= len(missing): for name, ip in zip(missing, candidates): found[name] = ip missing = sorted([n for n in target_names if n not in found]) if missing: discovered = ", ".join(sorted(seen_hostnames.keys())[:20]) if discovered: sys.stderr.write(f"Discovered hostnames during scan: {discovered}\n") if seen_ips: sample = ", ".join(f"{ip}={meta[0]}" for ip, meta in list(sorted(seen_ips.items()))[:20]) sys.stderr.write(f"SSH-reachable IPs: {sample}\n") raise SystemExit( "Failed SSH-based IP discovery for nodes: " + ", ".join(missing) + f" (scanned {prefix}.{start}-{prefix}.{end})" ) sys.stdout.write(build_inventory(target_names, found, ssh_user)) return 0 if __name__ == "__main__": raise SystemExit(main())