#!/usr/bin/env python3
"""TerraHome kubeadm bootstrap controller.

Reads a shell-style inventory file, rebuilds NixOS nodes over SSH, then
initializes a kubeadm cluster (primary control plane, Flannel CNI, extra
control planes, workers) and verifies node readiness.
"""
import argparse
import base64
import json
import os
import shlex
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path


def run_local(cmd, check=True, capture=False):
    """Run a command on the local host.

    A string `cmd` is executed through the shell; a list is executed
    directly.  Returns the `subprocess.CompletedProcess` (text mode).
    Raises `CalledProcessError` when `check` is true and the command fails.
    """
    shell = isinstance(cmd, str)
    return subprocess.run(
        cmd,
        shell=shell,
        check=check,
        text=True,
        capture_output=capture,
    )


def _numeric_suffix(key):
    """Sort key for inventory variables like CP_2 / WK_10 (numeric order)."""
    return int(key.split("_", 1)[1])


def _parse_inventory_env(env):
    """Derive (node_ips, cp_names, wk_names) from the sourced inventory env.

    Prefers the aggregate CONTROL_PLANES / WORKERS variables
    ("name=ip name=ip ..."); otherwise falls back to CP_<n> / WK_<n>
    variables, sorted numerically so CP_10 comes after CP_2.
    """
    node_ips = {}
    cp_names = []
    wk_names = []

    control_planes = env.get("CONTROL_PLANES", "").strip()
    workers = env.get("WORKERS", "").strip()

    if control_planes:
        for pair in control_planes.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            cp_names.append(name)
    else:
        cp_keys = [k for k in env if k.startswith("CP_") and k[3:].isdigit()]
        # Numeric sort: plain sorted() would put CP_10 before CP_2 and
        # silently pick the wrong primary control plane on large clusters.
        for key in sorted(cp_keys, key=_numeric_suffix):
            idx = key.split("_", 1)[1]
            name = f"cp-{idx}"
            node_ips[name] = env[key]
            cp_names.append(name)

    if workers:
        for pair in workers.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            wk_names.append(name)
    else:
        wk_keys = [k for k in env if k.startswith("WK_") and k[3:].isdigit()]
        for key in sorted(wk_keys, key=_numeric_suffix):
            idx = key.split("_", 1)[1]
            name = f"wk-{idx}"
            node_ips[name] = env[key]
            wk_names.append(name)

    return node_ips, cp_names, wk_names


def load_inventory(inventory_file):
    """Source `inventory_file` in bash and build the controller config dict.

    The inventory is a shell fragment; it is sourced with `set -a` so every
    variable lands in the environment, which is then dumped as JSON by an
    inline python3 heredoc and parsed here.

    Returns a dict with keys: env, node_ips, cp_names, wk_names,
    primary_cp, inventory_file.  Raises RuntimeError when the file is
    missing or defines no control planes / workers.
    """
    inventory_file = Path(inventory_file).resolve()
    if not inventory_file.exists():
        raise RuntimeError(f"Missing inventory file: {inventory_file}")

    cmd = (
        "set -a; "
        f"source {shlex.quote(str(inventory_file))}; "
        "python3 - <<'PY'\n"
        "import json, os\n"
        "print(json.dumps(dict(os.environ)))\n"
        "PY"
    )
    proc = run_local(["bash", "-lc", cmd], capture=True)
    env = json.loads(proc.stdout)

    node_ips, cp_names, wk_names = _parse_inventory_env(env)

    if not cp_names or not wk_names:
        raise RuntimeError("Inventory must include control planes and workers")

    primary_cp = env.get("PRIMARY_CONTROL_PLANE", "cp-1")
    if primary_cp not in node_ips:
        primary_cp = cp_names[0]

    return {
        "env": env,
        "node_ips": node_ips,
        "cp_names": cp_names,
        "wk_names": wk_names,
        "primary_cp": primary_cp,
        "inventory_file": str(inventory_file),
    }


class Controller:
    """Drives rebuild + kubeadm bootstrap stages against the inventory nodes."""

    def __init__(self, cfg):
        self.env = cfg["env"]
        self.node_ips = cfg["node_ips"]
        self.cp_names = cfg["cp_names"]
        self.wk_names = cfg["wk_names"]
        self.primary_cp = cfg["primary_cp"]
        self.primary_ip = self.node_ips[self.primary_cp]
        self.script_dir = Path(__file__).resolve().parent
        # Flake lives next to this script's parent unless overridden.
        self.flake_dir = Path(self.env.get("FLAKE_DIR") or (self.script_dir.parent)).resolve()
        self.ssh_user = self.env.get("SSH_USER", "micqdf")
        # Users tried in order until one authenticates (root first by default).
        self.ssh_candidates = self.env.get("SSH_USER_CANDIDATES", f"root {self.ssh_user}").split()
        self.active_ssh_user = self.ssh_user
        self.ssh_key = self.env.get("SSH_KEY_PATH", str(Path.home() / ".ssh" / "id_ed25519"))
        self.ssh_opts = [
            "-o", "BatchMode=yes",
            "-o", "IdentitiesOnly=yes",
            "-o", "StrictHostKeyChecking=accept-new",
            "-i", self.ssh_key,
        ]
        self.rebuild_timeout = self.env.get("REBUILD_TIMEOUT", "45m")
        self.rebuild_retries = int(self.env.get("REBUILD_RETRIES", "2"))
        self.worker_parallelism = int(self.env.get("WORKER_PARALLELISM", "3"))
        self.fast_mode = self.env.get("FAST_MODE", "1")
        self.skip_rebuild = self.env.get("SKIP_REBUILD", "0") == "1"
        # NOTE(review): force_reinit is set but never read in this file;
        # kept for backward compatibility with external consumers.
        self.force_reinit = True
        self.ssh_ready_retries = int(self.env.get("SSH_READY_RETRIES", "20"))
        self.ssh_ready_delay = int(self.env.get("SSH_READY_DELAY_SEC", "15"))

    def log(self, msg):
        """Print a progress line with the standard prefix."""
        print(f"==> {msg}")

    def _ssh(self, user, ip, cmd, check=True):
        """Run `cmd` on `ip` as `user` via ssh; returns the CompletedProcess."""
        full = ["ssh", *self.ssh_opts, f"{user}@{ip}", f"bash -lc {shlex.quote(cmd)}"]
        return run_local(full, check=check, capture=True)

    def detect_user(self, ip):
        """Find the first candidate SSH user that can log in to `ip`.

        Retries for a while because freshly rebuilt/rebooted nodes may not
        accept SSH immediately.  Sets `self.active_ssh_user` on success;
        raises RuntimeError when every candidate fails every attempt.
        """
        for attempt in range(1, self.ssh_ready_retries + 1):
            for user in self.ssh_candidates:
                proc = self._ssh(user, ip, "true", check=False)
                if proc.returncode == 0:
                    self.active_ssh_user = user
                    self.log(f"Using SSH user '{user}' for {ip}")
                    return
            if attempt < self.ssh_ready_retries:
                self.log(
                    f"SSH not ready on {ip} yet; retrying in {self.ssh_ready_delay}s "
                    f"({attempt}/{self.ssh_ready_retries})"
                )
                time.sleep(self.ssh_ready_delay)
        raise RuntimeError(f"Unable to authenticate to {ip} with users: {', '.join(self.ssh_candidates)}")

    def remote(self, ip, cmd, check=True):
        """Run `cmd` on `ip`, trying the active SSH user first.

        Exit code 255 means ssh itself failed (auth/connect), so the next
        candidate user is tried; any other non-zero code is a genuine
        remote command failure and stops the user rotation.  Raises
        RuntimeError (with captured output) when `check` is true and the
        command failed.
        """
        ordered = [self.active_ssh_user] + [u for u in self.ssh_candidates if u != self.active_ssh_user]
        last = None
        for user in ordered:
            proc = self._ssh(user, ip, cmd, check=False)
            if proc.returncode == 0:
                self.active_ssh_user = user
                return proc
            if proc.returncode != 255:
                # Command ran and failed; switching users won't help.
                last = proc
                break
            last = proc
        if check:
            stdout = (last.stdout or "").strip()
            stderr = (last.stderr or "").strip()
            raise RuntimeError(f"Remote command failed on {ip}: {cmd}\n{stdout}\n{stderr}")
        return last

    def prepare_known_hosts(self):
        """Refresh ~/.ssh/known_hosts entries for every inventory node."""
        ssh_dir = Path.home() / ".ssh"
        ssh_dir.mkdir(parents=True, exist_ok=True)
        (ssh_dir / "known_hosts").touch()
        run_local(["chmod", "700", str(ssh_dir)])
        run_local(["chmod", "600", str(ssh_dir / "known_hosts")])
        for ip in self.node_ips.values():
            # Drop any stale key first, then re-scan (best effort).
            run_local(["ssh-keygen", "-R", ip], check=False)
            run_local(f"ssh-keyscan -H {shlex.quote(ip)} >> {shlex.quote(str(ssh_dir / 'known_hosts'))}", check=False)

    def prepare_remote_nix(self, ip):
        """Ensure the remote nix daemon trusts root and the deploy user."""
        self.remote(ip, "sudo mkdir -p /etc/nix")
        self.remote(ip, "if [ -f /etc/nix/nix.conf ]; then sudo sed -i '/^trusted-users[[:space:]]*=/d' /etc/nix/nix.conf; fi")
        self.remote(ip, "echo 'trusted-users = root micqdf' | sudo tee -a /etc/nix/nix.conf >/dev/null")
        self.remote(ip, "sudo systemctl restart nix-daemon 2>/dev/null || true")

    def prepare_remote_kubelet(self, ip):
        """Stop/mask kubelet and clear its config so kubeadm can re-init."""
        self.remote(ip, "sudo systemctl stop kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl disable kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl mask kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl reset-failed kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo rm -f /var/lib/kubelet/config.yaml /var/lib/kubelet/kubeadm-flags.env || true")

    def prepare_remote_space(self, ip):
        """Best-effort disk cleanup on the remote node before a rebuild."""
        self.remote(ip, "sudo nix-collect-garbage -d || true")
        self.remote(ip, "sudo nix --extra-experimental-features nix-command store gc || true")
        self.remote(ip, "sudo rm -rf /tmp/nix* /tmp/nixos-rebuild* || true")

    def rebuild_node_once(self, name, ip):
        """One nixos-rebuild switch attempt against `name`; True on success.

        Output streams to the terminal (not captured) so rebuild progress
        is visible; the whole run is bounded by `timeout`.
        """
        self.detect_user(ip)
        cmd = [
            "timeout", self.rebuild_timeout,
            "nixos-rebuild", "switch",
            "--flake", f"{self.flake_dir}#{name}",
            "--target-host", f"{self.active_ssh_user}@{ip}",
            "--use-remote-sudo",
        ]
        env = os.environ.copy()
        # nixos-rebuild passes these options to its internal ssh calls.
        env["NIX_SSHOPTS"] = " ".join(self.ssh_opts)
        proc = subprocess.run(cmd, text=True, env=env)
        return proc.returncode == 0

    def rebuild_with_retry(self, name, ip):
        """Rebuild `name`, retrying up to `rebuild_retries` extra times."""
        max_attempts = self.rebuild_retries + 1
        for attempt in range(1, max_attempts + 1):
            self.log(f"Rebuild attempt {attempt}/{max_attempts} for {name}")
            if self.rebuild_node_once(name, ip):
                return
            if attempt < max_attempts:
                self.log(f"Rebuild failed for {name}, retrying in 20s")
                time.sleep(20)
        raise RuntimeError(f"Rebuild failed permanently for {name}")

    def stage_preflight(self):
        """Stage: refresh known_hosts and confirm SSH to the primary CP."""
        self.prepare_known_hosts()
        self.detect_user(self.primary_ip)

    def _prepare_node(self, ip):
        """Common per-node prep: nix trust, kubelet reset, optional cleanup."""
        self.prepare_remote_nix(ip)
        self.prepare_remote_kubelet(ip)
        if self.fast_mode != "1":
            self.prepare_remote_space(ip)

    def stage_rebuild(self):
        """Stage: rebuild control planes serially, then workers in parallel."""
        if self.skip_rebuild:
            self.log("Node rebuild already complete")
            return
        self.detect_user(self.primary_ip)
        for name in self.cp_names:
            ip = self.node_ips[name]
            self.log(f"Preparing and rebuilding {name} ({ip})")
            self._prepare_node(ip)
            self.rebuild_with_retry(name, ip)
        for name in self.wk_names:
            ip = self.node_ips[name]
            self.log(f"Preparing {name} ({ip})")
            self._prepare_node(ip)
        failures = []
        with ThreadPoolExecutor(max_workers=self.worker_parallelism) as pool:
            futures = {pool.submit(self.rebuild_with_retry, name, self.node_ips[name]): name for name in self.wk_names}
            for fut in as_completed(futures):
                name = futures[fut]
                try:
                    fut.result()
                except Exception as exc:
                    failures.append((name, str(exc)))
        if failures:
            raise RuntimeError(f"Worker rebuild failures: {failures}")

    def has_admin_conf(self):
        """True when /etc/kubernetes/admin.conf exists on the primary CP."""
        return self.remote(self.primary_ip, "sudo test -f /etc/kubernetes/admin.conf", check=False).returncode == 0

    def cluster_ready(self):
        """True when admin.conf exists and the API server answers /readyz."""
        cmd = (
            "sudo test -f /etc/kubernetes/admin.conf && "
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get --raw=/readyz >/dev/null 2>&1"
        )
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def stage_init_primary(self):
        """Stage: run the (project-provided) kubeadm init wrapper on the primary."""
        self.log(f"Initializing primary control plane on {self.primary_cp}")
        self.remote(self.primary_ip, "sudo th-kubeadm-init")

    def stage_install_cni(self):
        """Stage: ship the bundled Flannel manifest to the primary and apply it.

        The manifest is base64-encoded locally so it survives the shell
        round-trip; apply is retried because CRD registration can lag
        right after init.
        """
        self.log("Installing Flannel")
        manifest_path = self.script_dir.parent / "manifests" / "kube-flannel.yml"
        manifest_b64 = base64.b64encode(manifest_path.read_bytes()).decode()
        self.remote(
            self.primary_ip,
            (
                "sudo mkdir -p /var/lib/terrahome && "
                f"echo {shlex.quote(manifest_b64)} | base64 -d | sudo tee /var/lib/terrahome/kube-flannel.yml >/dev/null"
            ),
        )
        self.log("Waiting for API readiness before applying Flannel")
        ready = False
        for _ in range(30):
            if self.cluster_ready():
                ready = True
                break
            time.sleep(10)
        if not ready:
            raise RuntimeError("API server did not become ready before Flannel install")
        last_error = None
        for attempt in range(1, 6):
            proc = self.remote(
                self.primary_ip,
                "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf apply -f /var/lib/terrahome/kube-flannel.yml",
                check=False,
            )
            if proc.returncode == 0:
                return
            last_error = (proc.stdout or "") + ("\n" if proc.stdout and proc.stderr else "") + (proc.stderr or "")
            self.log(f"Flannel apply attempt {attempt}/5 failed; retrying in 15s")
            time.sleep(15)
        raise RuntimeError(f"Flannel apply failed after retries\n{last_error or ''}")

    def cluster_has_node(self, name):
        """True when the cluster already knows a node called `name`."""
        cmd = f"sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get node {shlex.quote(name)} >/dev/null 2>&1"
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def build_join_cmds(self):
        """Return (worker_join_cmd, control_plane_join_cmd) from the primary.

        A fresh bootstrap token and a fresh certificate key are generated
        on every call, so the commands are single-use-window artifacts.
        """
        join_cmd = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm token create --print-join-command",
        ).stdout.strip()
        cert_key = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm init phase upload-certs --upload-certs | tail -n 1",
        ).stdout.strip()
        cp_join = f"{join_cmd} --control-plane --certificate-key {cert_key}"
        return join_cmd, cp_join

    def stage_join_control_planes(self):
        """Stage: join every non-primary control plane that isn't in the cluster."""
        _, cp_join = self.build_join_cmds()
        for node in self.cp_names:
            if node == self.primary_cp:
                continue
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining control plane {node}")
            ip = self.node_ips[node]
            node_join = f"{cp_join} --node-name {node} --ignore-preflight-errors=NumCPU,HTTPProxyCIDR"
            self.remote(ip, f"sudo th-kubeadm-join-control-plane {shlex.quote(node_join)}")

    def stage_join_workers(self):
        """Stage: join every worker that isn't already in the cluster."""
        join_cmd, _ = self.build_join_cmds()
        for node in self.wk_names:
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining worker {node}")
            ip = self.node_ips[node]
            node_join = f"{join_cmd} --node-name {node} --ignore-preflight-errors=HTTPProxyCIDR"
            self.remote(ip, f"sudo th-kubeadm-join-worker {shlex.quote(node_join)}")

    def _dump_flannel_diagnostics(self):
        """Best-effort dump of Flannel DS/pod state and logs for debugging."""
        diagnostics = [
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel get ds -o wide || true",
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel get pods -o wide || true",
            "for p in $(sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel get pods -o name 2>/dev/null); do echo \"--- describe $p ---\"; sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel describe $p || true; done",
            "for p in $(sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel get pods -o name 2>/dev/null); do echo \"--- logs $p kube-flannel ---\"; sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel logs $p -c kube-flannel --tail=120 || true; echo \"--- logs $p install-cni-plugin ---\"; sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel logs $p -c install-cni-plugin --tail=120 || true; echo \"--- logs $p install-cni ---\"; sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel logs $p -c install-cni --tail=120 || true; done",
            "for p in $(sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel get pods -o name 2>/dev/null); do sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel logs --tail=120 $p || true; done",
        ]
        for cmd in diagnostics:
            proc = self.remote(self.primary_ip, cmd, check=False)
            print(proc.stdout)

    def stage_verify(self):
        """Stage: wait for Flannel rollout and all nodes Ready, then list nodes.

        On Flannel rollout failure, dumps diagnostics and re-raises.
        """
        self.log("Final node verification")
        try:
            self.remote(
                self.primary_ip,
                "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-flannel rollout status ds/kube-flannel-ds --timeout=10m",
            )
        except Exception:
            self.log("Flannel rollout failed; collecting diagnostics")
            self._dump_flannel_diagnostics()
            raise
        self.remote(
            self.primary_ip,
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf wait --for=condition=Ready nodes --all --timeout=10m",
        )
        proc = self.remote(self.primary_ip, "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get nodes -o wide")
        print(proc.stdout)

    def reconcile(self):
        """Run every stage in order: full bootstrap from bare nodes."""
        self.stage_preflight()
        self.stage_rebuild()
        self.stage_init_primary()
        self.stage_install_cni()
        self.stage_join_control_planes()
        self.stage_join_workers()
        self.stage_verify()


def main():
    """CLI entry point: parse args, load inventory, dispatch one stage."""
    parser = argparse.ArgumentParser(description="TerraHome kubeadm bootstrap controller")
    parser.add_argument("command", choices=[
        "reconcile",
        "preflight",
        "rebuild",
        "init-primary",
        "install-cni",
        "join-control-planes",
        "join-workers",
        "verify",
    ])
    parser.add_argument("--inventory", default=str(Path(__file__).resolve().parent.parent / "scripts" / "inventory.env"))
    args = parser.parse_args()
    cfg = load_inventory(args.inventory)
    ctl = Controller(cfg)
    dispatch = {
        "reconcile": ctl.reconcile,
        "preflight": ctl.stage_preflight,
        "rebuild": ctl.stage_rebuild,
        "init-primary": ctl.stage_init_primary,
        "install-cni": ctl.stage_install_cni,
        "join-control-planes": ctl.stage_join_control_planes,
        "join-workers": ctl.stage_join_workers,
        "verify": ctl.stage_verify,
    }
    try:
        dispatch[args.command]()
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()