#!/usr/bin/env python3
"""TerraHome kubeadm bootstrap controller.

Drives a NixOS/kubeadm cluster bring-up from an inventory file: rebuilds nodes
with nixos-rebuild, initializes the primary control plane, installs Cilium, and
joins the remaining control planes and workers. Stage progress is persisted on
the primary node so the controller can be re-run idempotently.
"""

import argparse
import base64
import json
import os
import shlex
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

REMOTE_STATE_PATH = "/var/lib/terrahome/bootstrap-state.json"


def run_local(cmd, check=True, capture=False):
    # String commands go through the shell; list commands are exec'd directly.
    shell = isinstance(cmd, str)
    return subprocess.run(
        cmd,
        shell=shell,
        check=check,
        text=True,
        capture_output=capture,
    )


def load_inventory(inventory_file):
    inventory_file = Path(inventory_file).resolve()
    if not inventory_file.exists():
        raise RuntimeError(f"Missing inventory file: {inventory_file}")

    # Source the inventory in bash with `set -a` so every assignment is
    # exported, then dump the resulting environment as JSON.
    cmd = (
        "set -a; "
        f"source {shlex.quote(str(inventory_file))}; "
        "python3 - <<'PY'\n"
        "import json, os\n"
        "print(json.dumps(dict(os.environ)))\n"
        "PY"
    )
    proc = run_local(["bash", "-lc", cmd], capture=True)
    env = json.loads(proc.stdout)

    node_ips = {}
    cp_names = []
    wk_names = []
    control_planes = env.get("CONTROL_PLANES", "").strip()
    workers = env.get("WORKERS", "").strip()

    # Preferred form: CONTROL_PLANES="cp-1=10.0.0.1 cp-2=10.0.0.2".
    # Fallback form: numbered CP_<n> / WK_<n> variables holding bare IPs.
    if control_planes:
        for pair in control_planes.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            cp_names.append(name)
    else:
        cp_keys = (k for k in env if k.startswith("CP_") and k[3:].isdigit())
        for key in sorted(cp_keys, key=lambda k: int(k.split("_", 1)[1])):
            idx = key.split("_", 1)[1]
            name = f"cp-{idx}"
            node_ips[name] = env[key]
            cp_names.append(name)

    if workers:
        for pair in workers.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            wk_names.append(name)
    else:
        wk_keys = (k for k in env if k.startswith("WK_") and k[3:].isdigit())
        for key in sorted(wk_keys, key=lambda k: int(k.split("_", 1)[1])):
            idx = key.split("_", 1)[1]
            name = f"wk-{idx}"
            node_ips[name] = env[key]
            wk_names.append(name)

    if not cp_names or not wk_names:
        raise RuntimeError("Inventory must include control planes and workers")

    primary_cp = env.get("PRIMARY_CONTROL_PLANE", "cp-1")
    if primary_cp not in node_ips:
        primary_cp = cp_names[0]

    return {
        "env": env,
        "node_ips": node_ips,
        "cp_names": cp_names,
        "wk_names": wk_names,
        "primary_cp": primary_cp,
        "inventory_file": str(inventory_file),
    }


class Controller:
    def __init__(self, cfg):
        self.env = cfg["env"]
        self.node_ips = cfg["node_ips"]
        self.cp_names = cfg["cp_names"]
        self.wk_names = cfg["wk_names"]
        self.primary_cp = cfg["primary_cp"]
        self.primary_ip = self.node_ips[self.primary_cp]

        self.script_dir = Path(__file__).resolve().parent
        self.flake_dir = Path(self.env.get("FLAKE_DIR") or self.script_dir.parent).resolve()
        self.local_state_path = self.script_dir / "bootstrap-state-last.json"

        self.ssh_user = self.env.get("SSH_USER", "micqdf")
        self.ssh_candidates = self.env.get("SSH_USER_CANDIDATES", f"root {self.ssh_user}").split()
        self.active_ssh_user = self.ssh_user
        self.ssh_key = self.env.get("SSH_KEY_PATH", str(Path.home() / ".ssh" / "id_ed25519"))
        self.ssh_opts = [
            "-o", "BatchMode=yes",
            "-o", "IdentitiesOnly=yes",
            "-o", "StrictHostKeyChecking=accept-new",
            "-i", self.ssh_key,
        ]

        self.rebuild_timeout = self.env.get("REBUILD_TIMEOUT", "45m")
        self.rebuild_retries = int(self.env.get("REBUILD_RETRIES", "2"))
        self.worker_parallelism = int(self.env.get("WORKER_PARALLELISM", "3"))
        self.fast_mode = self.env.get("FAST_MODE", "1")
        self.skip_rebuild = self.env.get("SKIP_REBUILD", "0") == "1"

    def log(self, msg):
        print(f"==> {msg}")

    # ---------- SSH helpers ----------

    def _ssh(self, user, ip, cmd, check=True):
        full = ["ssh", *self.ssh_opts, f"{user}@{ip}", f"bash -lc {shlex.quote(cmd)}"]
        return run_local(full, check=check, capture=True)

    def detect_user(self, ip):
        for user in self.ssh_candidates:
            proc = self._ssh(user, ip, "true", check=False)
            if proc.returncode == 0:
                self.active_ssh_user = user
                self.log(f"Using SSH user '{user}' for {ip}")
                return
        raise RuntimeError(
            f"Unable to authenticate to {ip} with users: {', '.join(self.ssh_candidates)}"
        )

    def remote(self, ip, cmd, check=True):
        # Try the last-known-good user first; fall back to other candidates
        # only on SSH-level failures (exit code 255).
        ordered = [self.active_ssh_user] + [u for u in self.ssh_candidates if u != self.active_ssh_user]
        last = None
        for user in ordered:
            proc = self._ssh(user, ip, cmd, check=False)
            if proc.returncode == 0:
                self.active_ssh_user = user
                return proc
            last = proc
            if proc.returncode != 255:
                # The command itself failed on the remote host; switching
                # users will not help.
                break
        if check:
            stdout = (last.stdout or "").strip()
            stderr = (last.stderr or "").strip()
            raise RuntimeError(f"Remote command failed on {ip}: {cmd}\n{stdout}\n{stderr}")
        return last

    def prepare_known_hosts(self):
        ssh_dir = Path.home() / ".ssh"
        ssh_dir.mkdir(parents=True, exist_ok=True)
        (ssh_dir / "known_hosts").touch()
        run_local(["chmod", "700", str(ssh_dir)])
        run_local(["chmod", "600", str(ssh_dir / "known_hosts")])
        for ip in self.node_ips.values():
            run_local(["ssh-keygen", "-R", ip], check=False)
            run_local(
                f"ssh-keyscan -H {shlex.quote(ip)} >> {shlex.quote(str(ssh_dir / 'known_hosts'))}",
                check=False,
            )

    # ---------- Bootstrap state tracking ----------

    def get_state(self):
        proc = self.remote(
            self.primary_ip,
            f"sudo test -f {REMOTE_STATE_PATH} && sudo cat {REMOTE_STATE_PATH} || echo '{{}}'",
        )
        try:
            return json.loads(proc.stdout.strip() or "{}")
        except Exception:
            return {}

    def set_state(self, state):
        payload = json.dumps(state, sort_keys=True)
        b64 = base64.b64encode(payload.encode()).decode()
        self.remote(
            self.primary_ip,
            (
                "sudo mkdir -p /var/lib/terrahome && "
                f"echo {shlex.quote(b64)} | base64 -d | sudo tee {REMOTE_STATE_PATH} >/dev/null"
            ),
        )
        # Keep a local copy of the last-written state for debugging.
        self.local_state_path.write_text(payload + "\n", encoding="utf-8")

    def mark_done(self, key):
        state = self.get_state()
        state[key] = True
        state["updated_at"] = int(time.time())
        self.set_state(state)

    def clear_done(self, keys):
        state = self.get_state()
        for key in keys:
            state.pop(key, None)
        state["updated_at"] = int(time.time())
        self.set_state(state)

    def stage_done(self, key):
        return bool(self.get_state().get(key))

    # ---------- Per-node preparation ----------

    def prepare_remote_nix(self, ip):
        self.remote(ip, "sudo mkdir -p /etc/nix")
        self.remote(ip, "if [ -f /etc/nix/nix.conf ]; then sudo sed -i '/^trusted-users[[:space:]]*=/d' /etc/nix/nix.conf; fi")
        self.remote(ip, f"echo 'trusted-users = root {self.ssh_user}' | sudo tee -a /etc/nix/nix.conf >/dev/null")
        self.remote(ip, "sudo systemctl restart nix-daemon 2>/dev/null || true")

    def prepare_remote_kubelet(self, ip):
        # Quiesce any pre-existing kubelet so the rebuild and kubeadm join
        # start from a clean slate.
        self.remote(ip, "sudo systemctl stop kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl disable kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl mask kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl reset-failed kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo rm -f /var/lib/kubelet/config.yaml /var/lib/kubelet/kubeadm-flags.env || true")

    def prepare_remote_space(self, ip):
        self.remote(ip, "sudo nix-collect-garbage -d || true")
        self.remote(ip, "sudo nix --extra-experimental-features nix-command store gc || true")
        self.remote(ip, "sudo rm -rf /tmp/nix* /tmp/nixos-rebuild* || true")

    # ---------- NixOS rebuild ----------

    def rebuild_node_once(self, name, ip):
        self.detect_user(ip)
        cmd = [
            "timeout", self.rebuild_timeout,
            "nixos-rebuild", "switch",
            "--flake", f"{self.flake_dir}#{name}",
            "--target-host", f"{self.active_ssh_user}@{ip}",
            "--use-remote-sudo",
        ]
        env = os.environ.copy()
        env["NIX_SSHOPTS"] = " ".join(self.ssh_opts)
        proc = subprocess.run(cmd, text=True, env=env)
        return proc.returncode == 0

    def rebuild_with_retry(self, name, ip):
        max_attempts = self.rebuild_retries + 1
        for attempt in range(1, max_attempts + 1):
            self.log(f"Rebuild attempt {attempt}/{max_attempts} for {name}")
            if self.rebuild_node_once(name, ip):
                return
            if attempt < max_attempts:
                self.log(f"Rebuild failed for {name}, retrying in 20s")
                time.sleep(20)
        raise RuntimeError(f"Rebuild failed permanently for {name}")

    # ---------- Stages ----------

    def stage_preflight(self):
        if self.stage_done("preflight_done"):
            self.log("Preflight already complete")
            return
        self.prepare_known_hosts()
        self.detect_user(self.primary_ip)
        self.mark_done("preflight_done")

    def stage_rebuild(self):
        if self.skip_rebuild and self.stage_done("nodes_rebuilt"):
            self.log("Node rebuild already complete")
            return
        self.detect_user(self.primary_ip)

        # Control planes are prepared and rebuilt serially.
        for name in self.cp_names:
            ip = self.node_ips[name]
            self.log(f"Preparing and rebuilding {name} ({ip})")
            self.prepare_remote_nix(ip)
            self.prepare_remote_kubelet(ip)
            if self.fast_mode != "1":
                self.prepare_remote_space(ip)
            self.rebuild_with_retry(name, ip)

        # Workers are prepared serially, then rebuilt in parallel.
        for name in self.wk_names:
            ip = self.node_ips[name]
            self.log(f"Preparing {name} ({ip})")
            self.prepare_remote_nix(ip)
            self.prepare_remote_kubelet(ip)
            if self.fast_mode != "1":
                self.prepare_remote_space(ip)

        failures = []
        with ThreadPoolExecutor(max_workers=self.worker_parallelism) as pool:
            futures = {
                pool.submit(self.rebuild_with_retry, name, self.node_ips[name]): name
                for name in self.wk_names
            }
            for fut in as_completed(futures):
                name = futures[fut]
                try:
                    fut.result()
                except Exception as exc:
                    failures.append((name, str(exc)))
        if failures:
            raise RuntimeError(f"Worker rebuild failures: {failures}")

        # Rebuild can invalidate prior bootstrap stages; force reconciliation.
        self.clear_done([
            "primary_initialized",
            "cni_installed",
            "control_planes_joined",
            "workers_joined",
            "verified",
        ])
        self.mark_done("nodes_rebuilt")

    def has_admin_conf(self):
        return self.remote(self.primary_ip, "sudo test -f /etc/kubernetes/admin.conf", check=False).returncode == 0

    def cluster_ready(self):
        cmd = (
            "sudo test -f /etc/kubernetes/admin.conf && "
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get --raw=/readyz >/dev/null 2>&1"
        )
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def stage_init_primary(self):
        if self.stage_done("primary_initialized") and self.has_admin_conf() and self.cluster_ready():
            self.log("Primary control plane init already complete")
            return
        if self.has_admin_conf() and self.cluster_ready():
            self.log("Existing cluster detected on primary control plane")
        else:
            self.log(f"Initializing primary control plane on {self.primary_cp}")
            self.remote(self.primary_ip, "sudo th-kubeadm-init")
        self.mark_done("primary_initialized")

    def stage_install_cni(self):
        if self.stage_done("cni_installed") and self.cluster_ready():
            self.log("CNI install already complete")
            return
        self.log("Installing or upgrading Cilium")
        self.remote(self.primary_ip, "sudo helm repo add cilium https://helm.cilium.io >/dev/null 2>&1 || true")
        self.remote(self.primary_ip, "sudo helm repo update >/dev/null")
        self.remote(self.primary_ip, "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf create namespace kube-system >/dev/null 2>&1 || true")
        self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf helm upgrade --install cilium cilium/cilium "
            "--namespace kube-system --set kubeProxyReplacement=true",
        )
        self.mark_done("cni_installed")

    def cluster_has_node(self, name):
        cmd = f"sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get node {shlex.quote(name)} >/dev/null 2>&1"
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def build_join_cmds(self):
        if not self.has_admin_conf():
            self.remote(self.primary_ip, "sudo th-kubeadm-init")
        join_cmd = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm token create --print-join-command",
        ).stdout.strip()
        cert_key = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm init phase upload-certs --upload-certs | tail -n 1",
        ).stdout.strip()
        cp_join = f"{join_cmd} --control-plane --certificate-key {cert_key}"
        return join_cmd, cp_join

    def stage_join_control_planes(self):
        if self.stage_done("control_planes_joined"):
            self.log("Control-plane join already complete")
            return
        _, cp_join = self.build_join_cmds()
        for node in self.cp_names:
            if node == self.primary_cp:
                continue
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining control plane {node}")
            ip = self.node_ips[node]
            node_join = f"{cp_join} --node-name {node}"
            self.remote(ip, f"sudo th-kubeadm-join-control-plane {shlex.quote(node_join)}")
        self.mark_done("control_planes_joined")

    def stage_join_workers(self):
        if self.stage_done("workers_joined"):
            self.log("Worker join already complete")
            return
        join_cmd, _ = self.build_join_cmds()
        for node in self.wk_names:
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining worker {node}")
            ip = self.node_ips[node]
            node_join = f"{join_cmd} --node-name {node}"
            self.remote(ip, f"sudo th-kubeadm-join-worker {shlex.quote(node_join)}")
        self.mark_done("workers_joined")

    def stage_verify(self):
        if self.stage_done("verified"):
            self.log("Verification already complete")
            return
        self.log("Final node verification")
        proc = self.remote(self.primary_ip, "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get nodes -o wide")
        print(proc.stdout)
        self.mark_done("verified")

    def reconcile(self):
        self.stage_preflight()
        self.stage_rebuild()
        self.stage_init_primary()
        self.stage_install_cni()
        self.stage_join_control_planes()
        self.stage_join_workers()
        self.stage_verify()


def main():
    parser = argparse.ArgumentParser(description="TerraHome kubeadm bootstrap controller")
    parser.add_argument("command", choices=[
        "reconcile",
        "preflight",
        "rebuild",
        "init-primary",
        "install-cni",
        "join-control-planes",
        "join-workers",
        "verify",
    ])
    parser.add_argument(
        "--inventory",
        default=str(Path(__file__).resolve().parent.parent / "scripts" / "inventory.env"),
    )
    args = parser.parse_args()
    cfg = load_inventory(args.inventory)
    ctl = Controller(cfg)
    dispatch = {
        "reconcile": ctl.reconcile,
        "preflight": ctl.stage_preflight,
        "rebuild": ctl.stage_rebuild,
        "init-primary": ctl.stage_init_primary,
        "install-cni": ctl.stage_install_cni,
        "join-control-planes": ctl.stage_join_control_planes,
        "join-workers": ctl.stage_join_workers,
        "verify": ctl.stage_verify,
    }
    try:
        dispatch[args.command]()
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
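
# Example invocations (a minimal sketch; the script filename "bootstrap.py" and
# the inventory path shown are assumptions for illustration, not part of the
# repo's documented interface):
#
#   ./bootstrap.py reconcile
#   ./bootstrap.py rebuild --inventory ./scripts/inventory.env
#   ./bootstrap.py join-workers
#
# Tunables such as REBUILD_TIMEOUT, WORKER_PARALLELISM, FAST_MODE, and
# SKIP_REBUILD are picked up from the inventory environment: the file is
# sourced with `set -a`, and exported variables from the calling shell are
# inherited by that bash process as well.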