TerraHome/nixos/kubeadm/bootstrap/controller.py
Commit a70de061b0 by MichaelFisher1997, 2026-03-04 22:26:43 +00:00

fix: wait for Cilium and node readiness before marking bootstrap success

Update the verification stage to block on the Cilium DaemonSet rollout and on all nodes reaching Ready. This prevents workflows from reporting success while the cluster is still NotReady immediately after join.

#!/usr/bin/env python3
import argparse
import base64
import json
import os
import shlex
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
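
# Stage-completion flags are persisted on the primary control plane at this
# path so repeated runs can resume and skip stages that already succeeded.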
REMOTE_STATE_PATH = "/var/lib/terrahome/bootstrap-state.json"


def run_local(cmd, check=True, capture=False):
    # Commands given as a plain string go through the shell; argv lists are
    # executed directly.
    shell = isinstance(cmd, str)
    return subprocess.run(
        cmd,
        shell=shell,
        check=check,
        text=True,
        capture_output=capture,
    )
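

# load_inventory sources the shell-style inventory file in bash, dumps the
# resulting environment as JSON, and parses node name/IP pairs from either the
# CONTROL_PLANES/WORKERS lists ("name=ip ...") or the CP_<n>/WK_<n> fallbacks.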
def load_inventory(inventory_file):
    inventory_file = Path(inventory_file).resolve()
    if not inventory_file.exists():
        raise RuntimeError(f"Missing inventory file: {inventory_file}")
    cmd = (
        "set -a; "
        f"source {shlex.quote(str(inventory_file))}; "
        "python3 - <<'PY'\n"
        "import json, os\n"
        "print(json.dumps(dict(os.environ)))\n"
        "PY"
    )
    proc = run_local(["bash", "-lc", cmd], capture=True)
    env = json.loads(proc.stdout)
    node_ips = {}
    cp_names = []
    wk_names = []
    control_planes = env.get("CONTROL_PLANES", "").strip()
    workers = env.get("WORKERS", "").strip()
    if control_planes:
        for pair in control_planes.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            cp_names.append(name)
    else:
        for key in sorted(k for k in env if k.startswith("CP_") and k[3:].isdigit()):
            idx = key.split("_", 1)[1]
            name = f"cp-{idx}"
            node_ips[name] = env[key]
            cp_names.append(name)
    if workers:
        for pair in workers.split():
            name, ip = pair.split("=", 1)
            node_ips[name] = ip
            wk_names.append(name)
    else:
        for key in sorted(k for k in env if k.startswith("WK_") and k[3:].isdigit()):
            idx = key.split("_", 1)[1]
            name = f"wk-{idx}"
            node_ips[name] = env[key]
            wk_names.append(name)
    if not cp_names or not wk_names:
        raise RuntimeError("Inventory must include control planes and workers")
    primary_cp = env.get("PRIMARY_CONTROL_PLANE", "cp-1")
    if primary_cp not in node_ips:
        primary_cp = cp_names[0]
    return {
        "env": env,
        "node_ips": node_ips,
        "cp_names": cp_names,
        "wk_names": wk_names,
        "primary_cp": primary_cp,
        "inventory_file": str(inventory_file),
    }
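

# The controller drives an idempotent, staged bootstrap: preflight, NixOS
# rebuilds, kubeadm init on the primary control plane, Cilium install,
# control-plane and worker joins, and a final readiness verification.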
class Controller:
    def __init__(self, cfg):
        self.env = cfg["env"]
        self.node_ips = cfg["node_ips"]
        self.cp_names = cfg["cp_names"]
        self.wk_names = cfg["wk_names"]
        self.primary_cp = cfg["primary_cp"]
        self.primary_ip = self.node_ips[self.primary_cp]
        self.script_dir = Path(__file__).resolve().parent
        self.flake_dir = Path(self.env.get("FLAKE_DIR") or self.script_dir.parent).resolve()
        self.local_state_path = self.script_dir / "bootstrap-state-last.json"
        self.ssh_user = self.env.get("SSH_USER", "micqdf")
        self.ssh_candidates = self.env.get("SSH_USER_CANDIDATES", f"root {self.ssh_user}").split()
        self.active_ssh_user = self.ssh_user
        self.ssh_key = self.env.get("SSH_KEY_PATH", str(Path.home() / ".ssh" / "id_ed25519"))
        self.ssh_opts = [
            "-o",
            "BatchMode=yes",
            "-o",
            "IdentitiesOnly=yes",
            "-o",
            "StrictHostKeyChecking=accept-new",
            "-i",
            self.ssh_key,
        ]
        self.rebuild_timeout = self.env.get("REBUILD_TIMEOUT", "45m")
        self.rebuild_retries = int(self.env.get("REBUILD_RETRIES", "2"))
        self.worker_parallelism = int(self.env.get("WORKER_PARALLELISM", "3"))
        self.fast_mode = self.env.get("FAST_MODE", "1")
        self.skip_rebuild = self.env.get("SKIP_REBUILD", "0") == "1"
        self.force_reinit = False

    def log(self, msg):
        print(f"==> {msg}")

    def _ssh(self, user, ip, cmd, check=True):
        full = ["ssh", *self.ssh_opts, f"{user}@{ip}", f"bash -lc {shlex.quote(cmd)}"]
        return run_local(full, check=check, capture=True)

    def detect_user(self, ip):
        for user in self.ssh_candidates:
            proc = self._ssh(user, ip, "true", check=False)
            if proc.returncode == 0:
                self.active_ssh_user = user
                self.log(f"Using SSH user '{user}' for {ip}")
                return
        raise RuntimeError(f"Unable to authenticate to {ip} with users: {', '.join(self.ssh_candidates)}")

    def remote(self, ip, cmd, check=True):
        # Try the last known-good user first. ssh exits with 255 on
        # connection/authentication failures, so only in that case do we fall
        # through to the next candidate user.
        ordered = [self.active_ssh_user] + [u for u in self.ssh_candidates if u != self.active_ssh_user]
        last = None
        for user in ordered:
            proc = self._ssh(user, ip, cmd, check=False)
            if proc.returncode == 0:
                self.active_ssh_user = user
                return proc
            if proc.returncode != 255:
                # The command itself failed; switching users will not help.
                last = proc
                break
            last = proc
        if check:
            stdout = (last.stdout or "").strip()
            stderr = (last.stderr or "").strip()
            raise RuntimeError(f"Remote command failed on {ip}: {cmd}\n{stdout}\n{stderr}")
        return last

    def prepare_known_hosts(self):
        ssh_dir = Path.home() / ".ssh"
        ssh_dir.mkdir(parents=True, exist_ok=True)
        (ssh_dir / "known_hosts").touch()
        run_local(["chmod", "700", str(ssh_dir)])
        run_local(["chmod", "600", str(ssh_dir / "known_hosts")])
        for ip in self.node_ips.values():
            run_local(["ssh-keygen", "-R", ip], check=False)
            run_local(f"ssh-keyscan -H {shlex.quote(ip)} >> {shlex.quote(str(ssh_dir / 'known_hosts'))}", check=False)

    def get_state(self):
        proc = self.remote(
            self.primary_ip,
            "sudo test -f /var/lib/terrahome/bootstrap-state.json && sudo cat /var/lib/terrahome/bootstrap-state.json || echo '{}'",
        )
        try:
            state = json.loads(proc.stdout.strip() or "{}")
        except Exception:
            state = {}
        return state

    def set_state(self, state):
        # Base64-encode the JSON payload so it survives shell quoting on the
        # way to the remote tee.
        payload = json.dumps(state, sort_keys=True)
        b64 = base64.b64encode(payload.encode()).decode()
        self.remote(
            self.primary_ip,
            (
                "sudo mkdir -p /var/lib/terrahome && "
                f"echo {shlex.quote(b64)} | base64 -d | sudo tee {REMOTE_STATE_PATH} >/dev/null"
            ),
        )
        self.local_state_path.write_text(payload + "\n", encoding="utf-8")

    def mark_done(self, key):
        state = self.get_state()
        state[key] = True
        state["updated_at"] = int(time.time())
        self.set_state(state)

    def clear_done(self, keys):
        state = self.get_state()
        for key in keys:
            state.pop(key, None)
        state["updated_at"] = int(time.time())
        self.set_state(state)

    def stage_done(self, key):
        return bool(self.get_state().get(key))
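
    # Per-node preparation before a rebuild: allow the deploy user as a Nix
    # trusted user, and stop/mask any stale kubelet so kubeadm can configure it
    # cleanly; prepare_remote_space additionally reclaims disk outside fast mode.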
    def prepare_remote_nix(self, ip):
        self.remote(ip, "sudo mkdir -p /etc/nix")
        self.remote(ip, "if [ -f /etc/nix/nix.conf ]; then sudo sed -i '/^trusted-users[[:space:]]*=/d' /etc/nix/nix.conf; fi")
        self.remote(ip, "echo 'trusted-users = root micqdf' | sudo tee -a /etc/nix/nix.conf >/dev/null")
        self.remote(ip, "sudo systemctl restart nix-daemon 2>/dev/null || true")

    def prepare_remote_kubelet(self, ip):
        self.remote(ip, "sudo systemctl stop kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl disable kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl mask kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo systemctl reset-failed kubelet >/dev/null 2>&1 || true")
        self.remote(ip, "sudo rm -f /var/lib/kubelet/config.yaml /var/lib/kubelet/kubeadm-flags.env || true")

    def prepare_remote_space(self, ip):
        self.remote(ip, "sudo nix-collect-garbage -d || true")
        self.remote(ip, "sudo nix --extra-experimental-features nix-command store gc || true")
        self.remote(ip, "sudo rm -rf /tmp/nix* /tmp/nixos-rebuild* || true")

    def rebuild_node_once(self, name, ip):
        self.detect_user(ip)
        cmd = [
            "timeout",
            self.rebuild_timeout,
            "nixos-rebuild",
            "switch",
            "--flake",
            f"{self.flake_dir}#{name}",
            "--target-host",
            f"{self.active_ssh_user}@{ip}",
            "--use-remote-sudo",
        ]
        env = os.environ.copy()
        # nixos-rebuild picks up SSH options for the target host from NIX_SSHOPTS.
        env["NIX_SSHOPTS"] = " ".join(self.ssh_opts)
        proc = subprocess.run(cmd, text=True, env=env)
        return proc.returncode == 0

    def rebuild_with_retry(self, name, ip):
        max_attempts = self.rebuild_retries + 1
        for attempt in range(1, max_attempts + 1):
            self.log(f"Rebuild attempt {attempt}/{max_attempts} for {name}")
            if self.rebuild_node_once(name, ip):
                return
            if attempt < max_attempts:
                self.log(f"Rebuild failed for {name}, retrying in 20s")
                time.sleep(20)
        raise RuntimeError(f"Rebuild failed permanently for {name}")

    def stage_preflight(self):
        if self.stage_done("preflight_done"):
            self.log("Preflight already complete")
            return
        self.prepare_known_hosts()
        self.detect_user(self.primary_ip)
        self.mark_done("preflight_done")

    def stage_rebuild(self):
        if self.skip_rebuild and self.stage_done("nodes_rebuilt"):
            self.log("Node rebuild already complete")
            return
        self.detect_user(self.primary_ip)
        for name in self.cp_names:
            ip = self.node_ips[name]
            self.log(f"Preparing and rebuilding {name} ({ip})")
            self.prepare_remote_nix(ip)
            self.prepare_remote_kubelet(ip)
            if self.fast_mode != "1":
                self.prepare_remote_space(ip)
            self.rebuild_with_retry(name, ip)
        for name in self.wk_names:
            ip = self.node_ips[name]
            self.log(f"Preparing {name} ({ip})")
            self.prepare_remote_nix(ip)
            self.prepare_remote_kubelet(ip)
            if self.fast_mode != "1":
                self.prepare_remote_space(ip)
        failures = []
        with ThreadPoolExecutor(max_workers=self.worker_parallelism) as pool:
            futures = {pool.submit(self.rebuild_with_retry, name, self.node_ips[name]): name for name in self.wk_names}
            for fut in as_completed(futures):
                name = futures[fut]
                try:
                    fut.result()
                except Exception as exc:
                    failures.append((name, str(exc)))
        if failures:
            raise RuntimeError(f"Worker rebuild failures: {failures}")
        # Rebuild can invalidate prior bootstrap stages; force reconciliation.
        self.force_reinit = True
        self.clear_done([
            "primary_initialized",
            "cni_installed",
            "control_planes_joined",
            "workers_joined",
            "verified",
        ])
        self.mark_done("nodes_rebuilt")

    def has_admin_conf(self):
        return self.remote(self.primary_ip, "sudo test -f /etc/kubernetes/admin.conf", check=False).returncode == 0

    def cluster_ready(self):
        cmd = "sudo test -f /etc/kubernetes/admin.conf && sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get --raw=/readyz >/dev/null 2>&1"
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def stage_init_primary(self):
        if (not self.force_reinit) and self.stage_done("primary_initialized") and self.has_admin_conf() and self.cluster_ready():
            self.log("Primary control plane init already complete")
            return
        if (not self.force_reinit) and self.has_admin_conf() and self.cluster_ready():
            self.log("Existing cluster detected on primary control plane")
        else:
            self.log(f"Initializing primary control plane on {self.primary_cp}")
            self.remote(self.primary_ip, "sudo th-kubeadm-init")
        self.mark_done("primary_initialized")

    def stage_install_cni(self):
        if self.stage_done("cni_installed") and self.cluster_ready():
            self.log("CNI install already complete")
            return
        self.log("Installing or upgrading Cilium")
        self.remote(self.primary_ip, "sudo helm repo add cilium https://helm.cilium.io >/dev/null 2>&1 || true")
        self.remote(self.primary_ip, "sudo helm repo update >/dev/null")
        self.remote(self.primary_ip, "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf create namespace kube-system >/dev/null 2>&1 || true")
        self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf helm upgrade --install cilium cilium/cilium --namespace kube-system --set kubeProxyReplacement=true",
        )
        self.mark_done("cni_installed")

    def cluster_has_node(self, name):
        cmd = f"sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get node {shlex.quote(name)} >/dev/null 2>&1"
        return self.remote(self.primary_ip, cmd, check=False).returncode == 0

    def build_join_cmds(self):
        # A fresh bootstrap token and a newly uploaded certificate key are
        # generated on each call, so joins never rely on expired credentials.
        if not self.has_admin_conf():
            self.remote(self.primary_ip, "sudo th-kubeadm-init")
        join_cmd = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm token create --print-join-command",
        ).stdout.strip()
        cert_key = self.remote(
            self.primary_ip,
            "sudo KUBECONFIG=/etc/kubernetes/admin.conf kubeadm init phase upload-certs --upload-certs | tail -n 1",
        ).stdout.strip()
        cp_join = f"{join_cmd} --control-plane --certificate-key {cert_key}"
        return join_cmd, cp_join

    def stage_join_control_planes(self):
        if self.stage_done("control_planes_joined"):
            self.log("Control-plane join already complete")
            return
        _, cp_join = self.build_join_cmds()
        for node in self.cp_names:
            if node == self.primary_cp:
                continue
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining control plane {node}")
            ip = self.node_ips[node]
            node_join = f"{cp_join} --node-name {node} --ignore-preflight-errors=NumCPU,HTTPProxyCIDR"
            self.remote(ip, f"sudo th-kubeadm-join-control-plane {shlex.quote(node_join)}")
        self.mark_done("control_planes_joined")

    def stage_join_workers(self):
        if self.stage_done("workers_joined"):
            self.log("Worker join already complete")
            return
        join_cmd, _ = self.build_join_cmds()
        for node in self.wk_names:
            if self.cluster_has_node(node):
                self.log(f"{node} already joined")
                continue
            self.log(f"Joining worker {node}")
            ip = self.node_ips[node]
            node_join = f"{join_cmd} --node-name {node} --ignore-preflight-errors=HTTPProxyCIDR"
            self.remote(ip, f"sudo th-kubeadm-join-worker {shlex.quote(node_join)}")
        self.mark_done("workers_joined")
    def stage_verify(self):
        if self.stage_done("verified"):
            self.log("Verification already complete")
            return
        self.log("Final node verification")
        self.remote(
            self.primary_ip,
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf -n kube-system rollout status ds/cilium --timeout=10m",
        )
        self.remote(
            self.primary_ip,
            "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf wait --for=condition=Ready nodes --all --timeout=10m",
        )
        proc = self.remote(self.primary_ip, "sudo kubectl --kubeconfig /etc/kubernetes/admin.conf get nodes -o wide")
        print(proc.stdout)
        self.mark_done("verified")

    def reconcile(self):
        self.stage_preflight()
        self.stage_rebuild()
        self.stage_init_primary()
        self.stage_install_cni()
        self.stage_join_control_planes()
        self.stage_join_workers()
        self.stage_verify()


def main():
    parser = argparse.ArgumentParser(description="TerraHome kubeadm bootstrap controller")
    parser.add_argument("command", choices=[
        "reconcile",
        "preflight",
        "rebuild",
        "init-primary",
        "install-cni",
        "join-control-planes",
        "join-workers",
        "verify",
    ])
    parser.add_argument("--inventory", default=str(Path(__file__).resolve().parent.parent / "scripts" / "inventory.env"))
    args = parser.parse_args()
    cfg = load_inventory(args.inventory)
    ctl = Controller(cfg)
    dispatch = {
        "reconcile": ctl.reconcile,
        "preflight": ctl.stage_preflight,
        "rebuild": ctl.stage_rebuild,
        "init-primary": ctl.stage_init_primary,
        "install-cni": ctl.stage_install_cni,
        "join-control-planes": ctl.stage_join_control_planes,
        "join-workers": ctl.stage_join_workers,
        "verify": ctl.stage_verify,
    }
    try:
        dispatch[args.command]()
    except Exception as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()