Skip to content
Closed
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mock
mypy
pre-commit
pytest
pytest-bdd
pytest-cov
pytest-mock
types-pyyaml
Expand Down
42 changes: 42 additions & 0 deletions task_processing/plugins/kubernetes/kube_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from http import HTTPStatus
from typing import List
from typing import Optional

from kubernetes import client as kube_client
Expand Down Expand Up @@ -176,6 +177,11 @@ def create_pod(
def get_pod(
self, namespace: str, pod_name: str, attempts: int = DEFAULT_ATTEMPTS,
) -> Optional[V1Pod]:
"""
Wrapper around read_namespaced_pod() in the kubernetes clientlib that adds in
retrying on ApiExceptions.
Returns V1Pod on success, None otherwise.
"""
max_attempts = attempts
while attempts:
try:
Expand All @@ -201,3 +207,39 @@ def get_pod(
raise
logger.info(f"Ran out of retries attempting to fetch pod {pod_name}.")
raise ExceededMaxAttempts(f'Retried fetching pod {pod_name} {max_attempts} times.')

def get_pods(
    self, namespace: str, attempts: int = DEFAULT_ATTEMPTS,
) -> Optional[List[V1Pod]]:
    """
    Wrapper around list_namespaced_pod() in the kubernetes clientlib that adds
    retrying on ApiExceptions.

    Returns a list of V1Pod on success, None when the API responds with 404,
    and raises ExceededMaxAttempts once all retries are exhausted. Any
    non-ApiException is logged and re-raised immediately.
    """
    initial_attempts = attempts
    while attempts:
        try:
            # list_namespaced_pod returns a V1PodList; .items holds the pods
            return self.core.list_namespaced_pod(namespace=namespace).items
        except ApiException as e:
            # A 404 here means there is nothing to fetch for this namespace
            if e.status == 404:
                logger.info(f"Found no pods in the namespace {namespace}.")
                return None
            # Only consume an attempt when the exception wasn't handled by
            # reloading the client config.
            if not self.maybe_reload_on_exception(exception=e) and attempts:
                logger.debug(
                    f"Failed to fetch pods in {namespace} due to unhandled API exception, "
                    "retrying.",
                    exc_info=True
                )
                attempts -= 1
        except Exception:
            # Anything unexpected is not retryable - surface it to the caller
            logger.exception(
                f"Failed to fetch pods in {namespace} due to unhandled exception."
            )
            raise
    logger.info(f"Ran out of retries attempting to fetch pods in namespace {namespace}.")
    raise ExceededMaxAttempts(
        f'Retried fetching pods in namespace {namespace} {initial_attempts} times.')
132 changes: 122 additions & 10 deletions task_processing/plugins/kubernetes/kubernetes_pod_executor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import queue
import threading
import time
from queue import Empty
from queue import Queue
from time import sleep
from typing import Collection
from typing import List
from typing import Optional
from typing import Tuple

from kubernetes import watch
from kubernetes.client import V1Affinity
Expand Down Expand Up @@ -37,7 +40,6 @@
from task_processing.plugins.kubernetes.utils import get_pod_empty_volumes
from task_processing.plugins.kubernetes.utils import get_pod_volumes
from task_processing.plugins.kubernetes.utils import get_sanitised_kubernetes_name

logger = logging.getLogger(__name__)

POD_WATCH_THREAD_JOIN_TIMEOUT_S = 1.0
Expand All @@ -61,7 +63,9 @@ def __init__(
kubeconfig_path: Optional[str] = None,
task_configs: Optional[Collection[KubernetesTaskConfig]] = [],
emit_events_without_state_transitions: bool = False,

refresh_reconciliation_thread_grace: int = 300,
refresh_reconciliation_thread_interval: int = 120,
enable_reconciliation: bool = False
) -> None:
if not version:
version = "unknown_task_processing"
Expand Down Expand Up @@ -109,6 +113,18 @@ def __init__(
)
self.pending_event_processing_thread.start()

# Seconds to wait before starting reconciliation
self.refresh_reconciliation_thread_grace = refresh_reconciliation_thread_grace
self.refresh_reconciliation_thread_interval = refresh_reconciliation_thread_interval
self.enable_reconciliation = enable_reconciliation

if self.enable_reconciliation:
self.reconciliation_task_thread = threading.Thread(
target=self._reconcile_task_loop,
daemon=True,
)
self.reconciliation_task_thread.start()

def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
""" Generates task_metadata in UNKNOWN state for an existing KubernetesTaskConfig.
Used during initialization or recovery for a task"""
Expand Down Expand Up @@ -171,6 +187,65 @@ def _pod_event_watch_loop(self) -> None:
"Exception encountered while watching Pod events - restarting watch!")
logger.debug("Exiting Pod event watcher - stop requested.")

def _group_pod_task_configs(
    self,
    pods: List[V1Pod]
) -> List[Tuple[KubernetesTaskConfig, Optional[V1Pod]]]:
    """
    Pair every tracked task_config with its live pod (or None).

    Called during the reconciliation loop: builds one (task_config, pod)
    tuple per entry in task_metadata, using the bulk-fetched pod list.
    Tasks whose pod is absent from the list get None as their pod.
    """
    # index the fetched pods by name for O(1) lookup per tracked task
    pods_by_name = {pod.metadata.name: pod for pod in pods}
    return [
        (metadata.task_config, pods_by_name.get(name))
        for name, metadata in self.task_metadata.items()
    ]

def _filter_task_configs_pods_to_reconcile(
    self,
    task_configs_pods: List[Tuple[KubernetesTaskConfig, Optional[V1Pod]]]
) -> List[Tuple[KubernetesTaskConfig, Optional[V1Pod]]]:
    """
    Called during reconciliation task loop in order to filter the task_configs/pods
    that are mismatched in task_metadata.

    :param task_configs_pods: (task_config, pod) pairs; pod is None when
        Kubernetes no longer knows about the task's pod.
    :return: the pairs that need reconciling - every pair with a missing pod,
        plus every pair whose pod phase maps to a task state different from
        the one recorded in task_metadata. Pods in unrecognized phases are
        logged and skipped.
    """
    task_configs_pods_to_reconcile = []
    phase_to_task_state = {
        "Succeeded": KubernetesTaskState.TASK_FINISHED,
        "Failed": KubernetesTaskState.TASK_FAILED,
        "Running": KubernetesTaskState.TASK_RUNNING,
        "Pending": KubernetesTaskState.TASK_PENDING,
        "Unknown": KubernetesTaskState.TASK_LOST
    }

    for task_config, pod in task_configs_pods:
        # If the pod is not returned then we must set the task_metadata to lost state
        # if k8s no longer knows about this pod/task, then there's no way for us to
        # figure out it's final state in an automated fashion
        if pod is None:
            task_configs_pods_to_reconcile.append((task_config, pod))
            continue
        pod_name = pod.metadata.name
        pod_phase = pod.status.phase
        task_metadata = self.task_metadata[pod_name]
        task_state = task_metadata.task_state
        if pod_phase not in phase_to_task_state:
            logger.debug(
                f"Found pod {pod_name} in unhandled phase during reconciliation: "
                f"{pod_phase} - ignoring."
            )
            # BUGFIX: previously fell through to the lookup below and raised
            # a KeyError for any phase outside the mapping.
            continue
        if phase_to_task_state[pod_phase] is not task_state:
            logger.debug(
                f"Mismatched event found for {pod_name} during reconciliation. "
                f"pod_phase: {pod_phase} - task_state: {task_state} - "
                f"task_metadata: {task_metadata}"
            )
            task_configs_pods_to_reconcile.append((task_config, pod))
    return task_configs_pods_to_reconcile

def __handle_deleted_pod_event(self, event: PodEvent) -> None:
pod = event["object"]
pod_name = pod.metadata.name
Expand Down Expand Up @@ -397,7 +472,7 @@ def _pending_event_processing_loop(self) -> None:
try:
event = self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)
self._process_pod_event(event)
except queue.Empty:
except Empty:
logger.debug(
f"Pending event queue remained empty after {QUEUE_GET_TIMEOUT_S} seconds.",
)
Expand All @@ -420,6 +495,38 @@ def _pending_event_processing_loop(self) -> None:

logger.debug("Exiting Pod event processing - stop requested.")

def _reconcile_task_loop(self) -> None:
    """
    Thread target: periodically reconcile task_metadata against Kubernetes.

    Sleeps for the configured grace period first, then loops until stop is
    requested: bulk-fetches pods, filters down to the tasks whose recorded
    state disagrees with Kubernetes, and reconciles each one.
    """
    logger.info(f'Waiting {self.refresh_reconciliation_thread_grace}s before doing work')
    sleep(self.refresh_reconciliation_thread_grace)
    logger.debug("Starting Pod task config reconciliation.")
    while not self.stopping:
        # fetch all pods in the target namespace in one request so that we
        # don't block on making N serial requests to the Kubernetes API
        try:
            all_pods = self.kube_client.get_pods(namespace=self.namespace)
        except Exception:
            logger.exception(
                f"Hit an exception attempting to fetch pods in namespace {self.namespace}")
            all_pods = None

        if all_pods is not None:
            # restrict the bulk-fetched pods to only the tasks we track so
            # we don't reconcile state for unrelated pods in the namespace
            grouped = self._group_pod_task_configs(all_pods)
            # only pairs whose K8s phase disagrees with task_metadata remain
            for config, pod in self._filter_task_configs_pods_to_reconcile(grouped):
                self.reconcile(config, pod)
        logger.info(f'Sleeping for {self.refresh_reconciliation_thread_interval}s')
        sleep(self.refresh_reconciliation_thread_interval)
    logger.debug("Exiting Pod task config reconciliation - stop requested.")

def _create_container_definition(
self,
name: str,
Expand Down Expand Up @@ -544,13 +651,14 @@ def run(self, task_config: KubernetesTaskConfig) -> Optional[str]:

return None

def reconcile(self, task_config: KubernetesTaskConfig) -> None:
def reconcile(self, task_config: KubernetesTaskConfig, pod: Optional[V1Pod] = None) -> None:
pod_name = task_config.pod_name
try:
pod = self.kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(f"Hit an exception attempting to fetch pod {pod_name}")
pod = None
if pod is None:
try:
pod = self.kube_client.get_pod(namespace=self.namespace, pod_name=pod_name)
except Exception:
logger.exception(f"Hit an exception attempting to fetch pod {pod_name}")
pod = None

if pod_name not in self.task_metadata:
self._initialize_existing_task(task_config)
Expand Down Expand Up @@ -622,6 +730,10 @@ def stop(self) -> None:
# but in that case we can be reasonably sure that we're not dropping any data.
self.pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)

if self.enable_reconciliation:
logger.debug("Signaling reconciliation task to stop.")
self.reconciliation_task_thread.join()

logger.debug("Waiting for all pending PodEvents to be processed...")
# once we've stopped updating the pending events queue, we then wait until we're done
# processing any events we've received - this will wait until task_done() has been
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/plugins/kubernetes/kube_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,55 @@ def test_KubeClient_get_pod():
mock_kube_client.CoreV1Api().read_namespaced_pod.assert_called_once_with(
namespace='ns', name='pod-name'
)


def test_KubeClient_get_pods_too_many_failures():
    """get_pods raises ExceededMaxAttempts after exhausting its retries."""
    with mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
        autospec=True
    ), mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_client",
        autospec=True
    ) as mock_kube_client, mock.patch.dict(
        os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
    ):
        mock_config_path = "/OVERRIDE.conf"
        mock_kube_client.CoreV1Api().list_namespaced_pod.side_effect = [ApiException, ApiException]
        client = KubeClient(kubeconfig_path=mock_config_path)
        # BUGFIX: pytest.raises now wraps only the raising call - previously
        # the whole test body was inside the raises block, so the call_count
        # assertion below was dead code and never executed.
        with pytest.raises(ExceededMaxAttempts):
            client.get_pods(namespace='ns', attempts=2)
        assert mock_kube_client.CoreV1Api().list_namespaced_pod.call_count == 2


def test_KubeClient_get_pods_unknown_exception():
    """Non-ApiException errors from the clientlib propagate out of get_pods."""
    with mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
        autospec=True
    ), mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_client",
        autospec=True
    ) as mock_client, mock.patch.dict(
        os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
    ), pytest.raises(Exception):
        config_path = "/OVERRIDE.conf"
        # a generic Exception is not retried - it should bubble straight up
        mock_client.CoreV1Api().list_namespaced_pod.side_effect = [Exception]
        client = KubeClient(kubeconfig_path=config_path)
        client.get_pods(namespace='ns', attempts=2)


def test_KubeClient_get_pods():
    """Happy path: get_pods calls list_namespaced_pod once with the namespace."""
    with mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
        autospec=True
    ), mock.patch(
        "task_processing.plugins.kubernetes.kube_client.kube_client",
        autospec=True
    ) as mock_client, mock.patch.dict(
        os.environ, {"KUBECONFIG": "/another/kube/config.conf"}
    ):
        config_path = "/OVERRIDE.conf"
        mock_client.CoreV1Api().list_namespaced_pod.return_value = mock.Mock()
        client = KubeClient(kubeconfig_path=config_path)
        client.get_pods(namespace='ns', attempts=1)
        mock_client.CoreV1Api().list_namespaced_pod.assert_called_once_with(
            namespace='ns'
        )
Loading