@a7ul
Last active October 7, 2021 20:44
Airflow operator to create kubernetes jobs in a separate GKE cluster via Airflow
# # Installation
# 1. Copy this gkejoboperator.py into your dags folder.
#
# 2. For this custom operator to be usable in the Composer environment, install these Python modules:
#
# // requirements.txt
# kubernetes==11.0.0
# pyyaml==5.3.1
#
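# One way (illustrative; the environment name and location below are assumptions) to install
# these into a Cloud Composer environment is via gcloud:
#
#   gcloud composer environments update my-composer-env \
#       --location europe-west1 \
#       --update-pypi-packages-from-file requirements.txt
#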
# 3. Add a connection in Airflow for the operator to use when connecting to the GKE cluster.
#
# In Airflow:
# Go to Admin -> Connections -> Create
#
# Enter the following details:
#
# Conn Id: my-gke-connection
# Conn Type: google cloud platform
# Project Id: mygcpproject
# Keyfile Json: Contents of the service-account.json with access to target k8s cluster
#
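# The same connection can also be created from the Airflow CLI instead of the UI. A rough
# sketch assuming Airflow 1.10 CLI syntax (the extra JSON below is an assumption, not a real key):
#
#   airflow connections --add \
#       --conn_id my-gke-connection \
#       --conn_type google_cloud_platform \
#       --conn_extra '{"extra__google_cloud_platform__project": "mygcpproject",
#                      "extra__google_cloud_platform__keyfile_dict": "<service-account.json contents>"}'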
#
# -- Now you are ready to use the operator in your dags
#
# # -----------------------------------
# # Usage in DAG files
# # -----------------------------------
#
#
# from gkejoboperator import GKEJobOperator
#
# dag = DAG(
#     'my_dag',
#     default_args=default_args,
#     catchup=False,
#     max_active_runs=1,
#     schedule_interval="0 15 * * 1-5",
# )
#
# my_job = GKEJobOperator(
#     dag=dag,
#     task_id='my_task_id',
#     gcp_conn_id='my-gke-connection',
#     cluster_name='mygkeclustername',
#     project_id='mygcpproject',
#     location='europe-west1',
#     job_yaml='''
#       apiVersion: batch/v1
#       kind: Job
#       metadata:
#         labels:
#           app: myapp
#           variant: somejob
#       spec:
#         ttlSecondsAfterFinished: 120
#         backoffLimit: 0
#         activeDeadlineSeconds: 3600
#         template:
#           metadata:
#             labels:
#               app: myapp
#               variant: somejob
#           spec:
#             restartPolicy: "Never"
#             containers:
#               - name: my-pod
#                 image: eu.gcr.io/mygcpproject/myapp:latest
#                 imagePullPolicy: Always
#                 command:
#                   - "node"
#                   - "/app/dist/scripts/somescript.js"
#                 envFrom:
#                   - secretRef:
#                       name: my-secrets
#                   - configMapRef:
#                       name: my-config
#     ''',
# )
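#
# Since `job_yaml` is declared as a template field below, Airflow renders it with Jinja before
# execution, so macros can be embedded directly in the YAML, e.g. (illustrative snippet; the
# extra argument is an assumption):
#
#                 command:
#                   - "node"
#                   - "/app/dist/scripts/somescript.js"
#                   - "--date={{ ds }}"
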
import logging
import os
import json
import re
import subprocess
import tempfile
import time
import unicodedata
from typing import Optional
import yaml
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from kubernetes import client, config
from kubernetes.client.rest import ApiException as K8sApiException
from urllib3.exceptions import ProtocolError


def slugify(text):
    """
    Converts to lowercase, removes non-word characters (alphanumerics and
    underscores) and converts spaces to hyphens.
    """
    text = unicodedata.normalize("NFKD", text).lower()
    return re.sub(r"[\W_]+", "-", text)
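
# Illustrative example of the slug this produces (the input value is an assumption):
#   slugify("job-My DAG-task_1")  ->  "job-my-dag-task-1"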


def serialize_labels(label_obj):
    labels = []
    for key in label_obj:
        labels.append(f"{key}={label_obj[key]}")
    return ",".join(labels)


class GKEJobOperator(BaseOperator):
    """
    Executes a task as a Kubernetes Job in the specified Google Kubernetes Engine cluster.

    This operator assumes that the system has gcloud installed and that a connection id has been
    configured with a service account that has access to the GKE cluster you want to connect to.

    The **minimum** required arguments are
    ``task_id``, ``gcp_conn_id``, ``cluster_name`` and ``job_yaml``.
    The **optional** arguments are
    ``project_id`` and ``location``.
    Note that ``project_id`` is required if it is not specified in ``gcp_conn_id``.

    :param gcp_conn_id: The google cloud connection id to use. This allows users to specify a service account.
    :type gcp_conn_id: str
    :param cluster_name: The name of the Google Kubernetes Engine cluster the job should be spawned in
    :type cluster_name: str
    :param job_yaml: The complete job yaml to deploy on the cluster.
    :type job_yaml: str
    :param project_id: GCP project id
    :type project_id: str
    :param location: The location of the cluster, e.g. europe-west1-b
    :type location: str
    """

    template_fields = ["job_yaml"]
    template_fields_renderers = {"job_yaml": "yaml"}

    @apply_defaults
    def __init__(self,
                 gcp_conn_id: str,
                 cluster_name: str,
                 job_yaml: str,
                 project_id: Optional[str] = None,
                 location: Optional[str] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # internal
        self.k8s_job_metadata = None  # Assigned after the job is created in k8s
        self.k8s_job_labels = {"dag_id": self.dag_id, "task_id": self.task_id}
        self.k8s_client = None  # Assigned after `execute` is called
        self.k8s_namespace: Optional[str] = None  # Assigned after `execute` is called
        # external
        self.gcp_conn_id = gcp_conn_id
        self.cluster_name = cluster_name
        self.job_yaml = job_yaml
        self.project_id = project_id
        self.location = "europe-west1-b" if location is None else location

    def execute(self, context):
        job_definition = self._parse_job_yaml(self.job_yaml)
        self.k8s_namespace = job_definition["metadata"]["namespace"]
        self.k8s_client = self._get_k8s_client()
        k8s_batch_client = self.k8s_client.BatchV1Api()
        job_response = k8s_batch_client.create_namespaced_job(self.k8s_namespace, job_definition)
        logging.info("Job created in K8s:")
        logging.info(f"{job_response.metadata}")
        self.k8s_job_metadata = job_response.metadata
        job_status = self._wait_for_job_to_end()
        self._get_job_logs()
        self._cleanup()
        if job_status is None or job_status.succeeded is None:
            # Mark the task as failed
            raise Exception("The job failed to complete")

    def on_kill(self):
        """
        Called when the task is killed, e.g. when it is manually marked as failed or success.
        """
        logging.info("The DAG job was killed!")
        self._cleanup()
        return super().on_kill()

    def _cleanup(self):
        logging.info("Cleaning up...")
        try:
            k8s_batch_client = self.k8s_client.BatchV1Api()
            job_name = self.k8s_job_metadata.name
            delete_options_current_job = self.k8s_client.V1DeleteOptions(
                propagation_policy="Foreground",
                grace_period_seconds=0
            )
            # Delete the current job
            k8s_batch_client.delete_namespaced_job(name=job_name, body=delete_options_current_job,
                                                   namespace=self.k8s_namespace)
            # Also clean up any stale successfully finished jobs that might exist for this dag+task.
            # Any failed ones have to be cleaned up manually!
            label_selector = serialize_labels(self.k8s_job_labels)
            field_selector = "status.successful=1"
            k8s_batch_client.delete_collection_namespaced_job(namespace=self.k8s_namespace,
                                                              field_selector=field_selector,
                                                              grace_period_seconds=0,
                                                              propagation_policy="Foreground",
                                                              label_selector=label_selector)
        except K8sApiException as e:
            logging.warning(f"Error while cleaning up {e}")

    def _parse_job_yaml(self, raw_yaml: str):
        job_yaml = yaml.safe_load(raw_yaml)
        job_yaml.setdefault("metadata", {})
        # We do not want to set a fixed job name,
        # since it may cause conflicts when this operator is used across multiple dags / tasks.
        # So we use generateName instead.
        job_name_prefix = job_yaml["metadata"].pop("name",
                                                   slugify(f"job-{self.dag_id}-{self.task_id}"))
        job_yaml["metadata"].setdefault("generateName", f"{job_name_prefix}-")
        job_yaml["metadata"].setdefault("labels", {})
        job_yaml["metadata"]["labels"].update(self.k8s_job_labels)
        job_yaml["metadata"].setdefault("namespace", "default")
        job_yaml["metadata"].setdefault("finalizers", [])
        if "foregroundDeletion" not in set(job_yaml["metadata"]["finalizers"]):
            job_yaml["metadata"]["finalizers"].append("foregroundDeletion")
        job_yaml.setdefault("spec", {})
        # ttlSecondsAfterFinished is still an alpha feature and hence unavailable in GKE clusters by default at the moment
        job_yaml["spec"].setdefault("ttlSecondsAfterFinished", 120)
        job_yaml["spec"].setdefault("backoffLimit", 0)
        # Default deadline for a job is 5 hours
        job_yaml["spec"].setdefault("activeDeadlineSeconds", 60 * 60 * 5)
        job_yaml["spec"].setdefault("template", {})
        job_yaml["spec"]["template"].setdefault("metadata", {})
        job_yaml["spec"]["template"]["metadata"].setdefault("labels", {})
        job_yaml["spec"]["template"]["metadata"]["labels"].update(self.k8s_job_labels)
        job_yaml["spec"]["template"].setdefault("spec", {})
        job_yaml["spec"]["template"]["spec"].setdefault("restartPolicy", "Never")
        logging.info("Job YAML:")
        logging.info(json.dumps(job_yaml, indent=2))
        return job_yaml
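
    # Illustrative sketch of the defaults applied above. A job_yaml that only lists containers
    # would end up with (assuming dag_id="my_dag", task_id="my_task"; values not from this file):
    #   metadata.generateName            -> "job-my-dag-my-task-"
    #   metadata.namespace               -> "default"
    #   metadata.labels                  -> {"dag_id": "my_dag", "task_id": "my_task"}
    #   spec.backoffLimit                -> 0
    #   spec.ttlSecondsAfterFinished     -> 120
    #   spec.activeDeadlineSeconds       -> 18000  (5 hours)
    #   spec.template.spec.restartPolicy -> "Never"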

    def _get_job_logs(self):
        job_name = self.k8s_job_metadata.name
        k8s_core_client = self.k8s_client.CoreV1Api()
        job_label_selector = f"job-name={job_name}"
        try:
            pod_response = k8s_core_client.list_namespaced_pod(namespace=self.k8s_namespace,
                                                               label_selector=job_label_selector)
            for item in pod_response.items:
                pod_name = item.metadata.name
                try:
                    # For whatever reason the response contains only the first few characters unless
                    # the call is made with `_return_http_data_only=True, _preload_content=False`
                    pod_log_response = k8s_core_client.read_namespaced_pod_log(
                        name=pod_name,
                        namespace=self.k8s_namespace,
                        _return_http_data_only=True,
                        _preload_content=False,
                        timestamps=True
                    )
                    pod_log = pod_log_response.data.decode("utf-8")
                    logging.info(f"Logs for {pod_name}:")
                    logging.info(pod_log)
                except K8sApiException:
                    logging.warning(f"Exception when reading log for {pod_name}")
        except K8sApiException as e:
            logging.warning(f"Found exception while listing pods for the job {e}")

    def _wait_for_job_to_end(self):
        k8s_batch_client = self.k8s_client.BatchV1Api()
        job_name = self.k8s_job_metadata.name
        job_status = None
        logging.info("Waiting for the job to finish...")
        try:
            while True:
                try:
                    job = k8s_batch_client.read_namespaced_job(namespace=self.k8s_namespace,
                                                               name=job_name)
                    job_status = job.status
                    if job.status.active is None and job.status.start_time is not None:
                        logging.info(f"Job status for K8s job {job_name}: {job.status}")
                        break
                except ProtocolError:
                    logging.warning("Ignoring ProtocolError and continuing...")
                time.sleep(5)
        except K8sApiException as e:
            logging.warning(f"Error while reading status {e}")
        return job_status

    def _get_k8s_client(self):
        gcp = GoogleCloudBaseHook(gcp_conn_id=self.gcp_conn_id)
        gcp_service_account_path = gcp._get_field("key_path", False)
        gcp_service_account_json = ""
        if gcp_service_account_path:
            with open(gcp_service_account_path) as f:
                gcp_service_account_json = f.read()
        else:
            gcp_service_account_json = gcp._get_field("keyfile_dict", False)
        self.project_id = gcp.project_id if self.project_id is None else self.project_id
        with tempfile.NamedTemporaryFile("w+", suffix=".json",
                                         encoding="utf8") as gcloud_service_account_file:
            gcloud_service_account_file.write(gcp_service_account_json)
            gcloud_service_account_file.seek(0)
            with tempfile.NamedTemporaryFile("w+") as kube_config_file:
                kube_config_file.seek(0)
                custom_env = os.environ.copy()
                custom_env["KUBECONFIG"] = kube_config_file.name
                custom_env["GOOGLE_APPLICATION_CREDENTIALS"] = gcloud_service_account_file.name
                subprocess.check_call(
                    ["gcloud", "auth", "activate-service-account",
                     "--key-file", gcloud_service_account_file.name]
                )
                subprocess.check_call(
                    ["gcloud", "container", "clusters", "get-credentials",
                     self.cluster_name,
                     "--region", self.location,
                     "--project", self.project_id
                     ],
                    env=custom_env
                )
                # Point the kubernetes client used by `GKEJobOperator` at the kubeconfig written above
                config.load_kube_config(kube_config_file.name)
                return client