
A Notebook Workflow for Rapidly Deploying Model Serving on Databricks

This article is the Day 25 entry of the Databricks - Qiita Advent Calendar 2024, Series 2.

This post introduces the Notebook template workflow we use when deploying Model Serving.

Model Serving can also be deployed with Terraform. However, once you take model operations such as retraining into account, the model version managed in Terraform can end up out of sync with the version actually being served. We therefore manage only the Notebook Job with Terraform and deploy Model Serving from inside the Notebook.
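
To make this concrete: with this setup, a (re)deployment is simply a run of the Terraform-managed Notebook Job. Below is a minimal sketch using the Databricks SDK; the SDK usage is an illustrative assumption and not part of the setup described in this post, while the job name matches the YAML shown later.

# Hypothetical sketch: trigger a deployment by running the Terraform-managed Job.
# Assumes the databricks-sdk package and a configured workspace profile.
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
# Look up the Job that Terraform created (name taken from the YAML below).
job = next(w.jobs.list(name="wf_test_model_endpoint_tmpl"))
# Run it; the Job-level "params" default defined in Terraform is passed to the notebook.
run = w.jobs.run_now(job_id=job.job_id)
run.result()  # block until the deployment run finishes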

When deploying Model Serving, everything other than the model logic is factored out and managed as a shared template. With this template, a model author only has to edit the logic section to deploy Model Serving.

docs.databricks.com

Notebook Template

The template looks like the following. The sections marked "EDITABLE" are the parts the model author edits: users copy the template and modify only the model logic inside the "EDITABLE" block. The "pre_hook" and "post_hook" calls before and after the "EDITABLE" block are shared processing; as described later, they register, update, or delete the model and the Model Serving endpoint based on the parameters passed to the Notebook. Keeping the "EDITABLE" part in a separate file would make the template simpler, but we chose this layout so that model run results and other output are easy to visualize and check on Databricks.

# Databricks notebook source

# MAGIC %run ./model_serving

# COMMAND ----------
import json
import mlflow
mlflow.set_registry_uri("databricks-uc")

# COMMAND ----------

params_string = dbutils.widgets.get("params")
params = json.loads(params_string)
print(f'params: {params}')


# COMMAND ----------

# MAGIC %python
# MAGIC pre_hook(params)

# COMMAND ----------
################ EDITABLE ################ 
################ START ################ 
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

model = params['model']
artifact_path = model['artifact_path']
endpoint = params['endpoint']
model_name = endpoint['config']['served_entities'][0]['entity_name']

experiment_path = model['experiment_path']
mlflow.set_experiment(experiment_path)

with mlflow.start_run():
    # Train a sklearn model on the iris dataset
    X, y = datasets.load_iris(return_X_y=True, as_frame=True)
    clf = RandomForestClassifier(max_depth=7)
    clf.fit(X, y)
    # Take the first row of the training dataset as the model input example.
    input_example = X.iloc[[0]]
    # Log the model and register it as a new version in UC.
    mlflow.sklearn.log_model(
        sk_model=clf,
        artifact_path=artifact_path,
        # The signature is automatically inferred from the input example and its predicted output.
        input_example=input_example,
        registered_model_name=model_name,
    )
################ END ################ 

# COMMAND ----------

# MAGIC %python
# MAGIC post_hook(params)
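
For reference, the "params" widget the template reads can also be populated by hand when trying the template interactively, before wiring it into a Job. The payload below is a minimal sketch that mirrors the job_params YAML shown later; the catalog, schema, and endpoint names are illustrative placeholders.

# Hypothetical sketch: pre-populate the "params" widget for an interactive run.
import json

sample_params = {
    "force_delete": "false",
    "model": {
        "experiment_path": "/Shared/common_model_experiments/sample",
        "artifact_path": "model",
    },
    "endpoint": {
        "endpoint_name": "workspace-model-endpoint",
        "config": {
            "served_entities": [
                {
                    "name": "iris_model_serving",
                    "entity_name": "dev_data_science.<project>.sample_model",  # placeholder, as in the YAML
                    "entity_version": "latest",
                    "workload_size": "Small",
                    "scale_to_zero_enabled": "true",
                }
            ]
        },
    },
}
dbutils.widgets.text("params", json.dumps(sample_params))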

Next, here are the "pre_hook" and "post_hook" referenced by the template. As the code below shows, they create, update, or delete the model and the Model Serving endpoint based on the Notebook parameters. Reading through the "pre_hook" and "post_hook" logic first should give you a feel for the overall flow.

# Databricks notebook source

from mlflow.deployments import get_deploy_client
from mlflow.tracking import MlflowClient
import time
from distutils.util import strtobool


# COMMAND ----------

def pre_hook(params):
    model = params['model']
    force_delete = bool(strtobool(params['force_delete']))
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'config: {config}')
    print(f'force_delete: {force_delete}')
    if model_registered_exists(model_name) and force_delete:
        delete_registered_model(model_name)
    print('skip delete_registered_model')
    if endpoint_exists(endpoint_name) and force_delete:
        delete_model_serving_endpoint(endpoint_name)
        dbutils.notebook.exit("Model serving endpoint deleted. Exiting the notebook.")
    print('skip delete_model_serving_endpoint')
    if model_registered_exists(model_name) and endpoint_exists(endpoint_name) and model_version != 'latest':
        update_model_serving_endpoint(endpoint_name, config)
        wait_for_endpoint_ready(endpoint_name)
        dbutils.notebook.exit(f'Updated model endpoint version: {model_version}. Exiting the notebook')
    print('skip update_model_serving_endpoint')

def post_hook(params):
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'config: {config}')

    if model_version == 'latest':
        model_version = get_model_registered_latest_version(model_name)
        config['served_entities'][0]['entity_version'] = model_version
    print(f'model version: {model_version}')
    # Check if the endpoint exists
    if endpoint_exists(endpoint_name):
        print(f"Endpoint '{endpoint_name}' exists.")
        wait_for_endpoint_ready(endpoint_name)
        print(f'update_model_serving_endpoint')
        update_model_serving_endpoint(endpoint_name, config)
        
    else:
        print(f"Endpoint '{endpoint_name}' does not exist.")
        print(f'create_model_serving_endpoint')
        create_model_serving_endpoint(endpoint_name, config)

    wait_for_endpoint_ready(endpoint_name)
    print('end')

# Registered models
def delete_registered_model(model_name):
    try:
        client = MlflowClient()
        client.delete_registered_model(name=model_name)
        print(f"Model '{model_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model '{model_name}': {e}")
        raise

    
def model_registered_exists(model_name):
    try:
        client = MlflowClient()
        registered_models = client.search_registered_models()
        return any(model.name == model_name for model in registered_models)
    except Exception as e:
        print(f"Error checking if model '{model_name}' is registered in Unity Catalog: {e}")
        return False

def get_model_registered_latest_version(model_name):
    try:
        client = MlflowClient()
        print(f'model_name:{model_name}')
        model_version_infos = client.search_model_versions("name = '%s'" % model_name)
        print(f'model_version_infos:{model_version_infos}')
        latest_version = max([int(model_version_info.version) for model_version_info in model_version_infos])
        return latest_version
    except Exception as e:
        print(f"Error get_model_registered_latest_version: {e}")
        raise


# Model serving
def delete_model_serving_endpoint(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.delete_endpoint(endpoint=endpoint_name)
        print(f"Model serving endpoint '{endpoint_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model serving endpoint '{endpoint_name}': {e}")
        raise

def create_model_serving_endpoint(name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.create_endpoint(name=name,config=config)
        print(f"Model serving endpoint '{name}' has been created. config: {config}")
    except Exception as e:
        print(f"Error createing model serving endpoint '{name}': {e}")
        raise

def update_model_serving_endpoint(endpoint_name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.update_endpoint(endpoint=endpoint_name,config=config)
        print(f"Model serving endpoint '{endpoint_name}' has been updated. config: {config}")
    except Exception as e:
        print(f"Error updating model serving endpoint '{endpoint_name}': {e}")
        raise

def endpoint_exists(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoints = deploy_client.list_endpoints()
        return any(endpoint['name'] == endpoint_name for endpoint in endpoints)
    except Exception as e:
        print(f"Error checking if endpoint exists '{endpoint_name}': {e}")
        return False

def get_endpoint_status(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoint_status = deploy_client.get_endpoint(endpoint=endpoint_name)
        return endpoint_status
    except Exception as e:
        print(f"Error getting status for endpoint '{endpoint_name}': {e}")
        raise

def wait_for_endpoint_ready(endpoint_name, timeout=1000, interval=30):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # https://docs.databricks.com/api/workspace/servingendpoints/get
            status = get_endpoint_status(endpoint_name)
            if status['state']['ready'] == 'READY':
                print(f"Endpoint '{endpoint_name}' is ready.")

            if status['state']['config_update'] == 'NOT_UPDATING':
                print(f"Endpoint '{endpoint_name}' is currently NOT_UPDATING.")
                return
            else:
                print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
        except Exception as e:
            if "RESOURCE_CONFLICT" in str(e):
                print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
            else:
                raise
        time.sleep(interval)
    raise TimeoutError(f"Endpoint '{endpoint_name}' is not ready after {timeout} seconds.")
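
Once post_hook has finished and the endpoint is ready, one way to sanity-check the deployment is to score the endpoint with the same deployments client used above. This is a minimal sketch; the endpoint name and the iris-style payload are illustrative assumptions.

# Hypothetical sketch: send a test request to the deployed endpoint.
from mlflow.deployments import get_deploy_client

deploy_client = get_deploy_client("databricks")
response = deploy_client.predict(
    endpoint="workspace-model-endpoint",
    inputs={
        "dataframe_split": {
            "columns": ["sepal length (cm)", "sepal width (cm)",
                        "petal length (cm)", "petal width (cm)"],
            "data": [[5.1, 3.5, 1.4, 0.2]],
        }
    },
)
print(response)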

Databricks Job (Terraform)

The Notebook Job is deployed with Terraform, and its parameters are defined in the YAML file below. Based on this YAML configuration, Terraform deploys the Notebook Job dynamically.

jobs:
  wf_test_model_endpoint_tmpl:
    name: wf_test_model_endpoint_tmpl
    domain_tag: test
    notebook_path: "notebook/model_serving/model_endpoint.tmpl"
    job_params:
      {
        "force_delete": "false",
        "model":
          {
            "experiment_path": "/Shared/common_model_experiments/sample",
            "artifact_path": "model",
          },
        "endpoint":
          {
            "endpoint_name": "workspace-model-endpoint",
            "config": { "served_entities": [
                    {
                      "name": "iris_model_serving",
                      "entity_name": "${Env}_data_science.<project>.sample_model",
                      "entity_version": "latest", # latest or version number for rollback
                      "workload_size": "Small",
                      "scale_to_zero_enabled": "true",
                    },
                  ], "auto_capture_config": { "catalog_name": "${Env}_catalog", "schema_name": "sample_schema, "enabled": "true" } },
          },
      }
    clusters:
      {
        spark_version: "15.3.x-cpu-ml-scala2.12",
        node_type_id: "i3.2xlarge",
        driver_node_type_id: "i3.2xlarge",
        autoscale: { min_workers: 2, max_workers: 10 },
      }
    template_path: "../jobs/template/model_serving_template.json"
    git_url: "https://git-codecommit.ap-northeast-1.amazonaws.com/v1/repos/<repository_name>"

The Notebook Job template created for Model Serving is shown below. It uses the parameters defined in the YAML to build the Job definition dynamically.

{
    "name": "${name}",
    "email_notifications": {
        "no_alert_for_skipped_runs": false
    },
    "notification_settings": {
        "no_alert_for_canceled_runs": true
    },
    "webhook_notifications": {},
    "timeout_seconds": 0,
    "max_concurrent_runs": 1,
    "tags": {
        "product": "${domain_tag}"
    },
    "parameters": [
        {
            "name": "params",
            "default": "${params}"
        }
    ],
    "job_clusters": [
        {
            "job_cluster_key": "job_cluster_key_${env}",
            "new_cluster": {
                "spark_version": "${spark_version}",
                "node_type_id": "${node_type_id}",
                "driver_node_type_id": "${driver_node_type_id}",
                "policy_id": "${policy_id}",
                "autoscale": {
                    "min_workers": "${min_workers}",
                    "max_workers": "${max_workers}"
                },
                "aws_attributes": {
                    "first_on_demand": "${first_on_demand}"
                }
            }
        }
    ],
    "tasks": [
        {
            "task_key": "deploy_model_serving",
            "max_retries": 0,
            "notebook_task": {
                "notebook_path": "${notebook_path}",
                "source": "GIT"
            },
            "job_cluster_key": "job_cluster_key_${env}",
            "libraries": [
                {
                    "pypi": {
                        "package": "mlflow-skinny[databricks]>=2.5.0"
                    }
                }
            ]
        }
    ],
    "git_source": {
        "git_url": "${git_url}",
        "git_provider": "awscodecommit",
        "git_tag": "t_${env}"
    },
    "format": "MULTI_TASK"
}

The dynamically built Job information is passed to the Job defined in the following Terraform. This setup reuses a common definition for job resources beyond Model Serving, which keeps management efficient: per-job differences are absorbed in locals, and all job resource information is received in one place via local.job-association-map.

resource "databricks_job" "job" {
  for_each = local.job-association-map
  depends_on = [
    databricks_cluster.shared_dbx_cluster
  ]
  name                = each.value.name
  timeout_seconds     = each.value.timeout_seconds
  max_concurrent_runs = each.value.max_concurrent_runs
  git_source {
    url      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_url : null
    provider = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_provider : null
    tag      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_tag : null
  }
  tags = {
    product = lookup(each.value, "tags", null) != null ? each.value.tags.product : local.tags.product
  }
  email_notifications {
    no_alert_for_skipped_runs = lookup(each.value.email_notifications, "no_alert_for_skipped_runs", null) != null ? each.value.email_notifications.no_alert_for_skipped_runs : null
    on_start                  = lookup(each.value.email_notifications, "on_start", []) != [] ? each.value.email_notifications.on_start : []
    on_success                = lookup(each.value.email_notifications, "on_success", []) != [] ? each.value.email_notifications.on_success : []
    on_failure                = lookup(each.value.email_notifications, "on_failure", []) != [] ? each.value.email_notifications.on_failure : local.on_failure
  }
  format = each.value.format

  dynamic "trigger" {
    for_each = { for key, val in each.value :
    key => val if key == "trigger" && val != null }
    content {
      pause_status = trigger.value.pause_status
      file_arrival {
        url = trigger.value.file_arrival.url
      }
    }
  }

  # Use an existing cluster instead of new_cluster. This will be used for IDBCDB, to import existing resources.
  dynamic "job_cluster" {
    for_each = { for key, val in each.value.job_clusters :
    key => val if lookup(val, "new_cluster", null) != null }
    content {
      job_cluster_key = each.value.job_clusters[0].job_cluster_key
      new_cluster {
        spark_version       = lookup(each.value.job_clusters[0].new_cluster, "spark_version", null) != null ? each.value.job_clusters[0].new_cluster.spark_version : local.clusters.spark_version
        node_type_id        = lookup(each.value.job_clusters[0].new_cluster, "node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.node_type_id : local.clusters.node_type_id
        driver_node_type_id = lookup(each.value.job_clusters[0].new_cluster, "driver_node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.driver_node_type_id : local.clusters.driver_node_type_id
        policy_id           = lookup(each.value.job_clusters[0].new_cluster, "policy_id", null) != null ? each.value.job_clusters[0].new_cluster.policy_id : local.clusters.policy_id
        runtime_engine      = lookup(each.value.job_clusters[0].new_cluster, "runtime_engine", null) != null ? each.value.job_clusters[0].new_cluster.runtime_engine : local.clusters.runtime_engine
        spark_conf          = lookup(each.value.job_clusters[0].new_cluster, "spark_conf", null) != null ? each.value.job_clusters[0].new_cluster.spark_conf : null
        autoscale {
          min_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "min_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.min_workers : local.clusters.autoscale.min_workers
          max_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "max_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.max_workers : local.clusters.autoscale.max_workers
        }
        aws_attributes {
          first_on_demand        = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "first_on_demand", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.first_on_demand : local.clusters.aws_attributes.first_on_demand
          availability           = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "availability", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.availability : local.clusters.aws_attributes.availability
          instance_profile_arn   = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "instance_profile_arn", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.instance_profile_arn : local.clusters.aws_attributes.instance_profile_arn
          zone_id                = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "zone_id", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.zone_id : local.clusters.aws_attributes.zone_id
          spot_bid_price_percent = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "spot_bid_price_percent", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.spot_bid_price_percent : local.clusters.aws_attributes.spot_bid_price_percent
        }
        custom_tags = each.value.job_clusters[0].new_cluster.custom_tags
      }
    }
  }

  dynamic "notification_settings" {
    for_each = { for key, val in each.value :
    key => val if key == "notification_settings" && val != null }
    content {
      no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : null
      no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : null
    }
  }

  dynamic "schedule" {
    for_each = { for key, val in each.value :
    key => val if key == "schedule" && val != null }
    content {
      pause_status           = schedule.value.pause_status
      quartz_cron_expression = schedule.value.quartz_cron_expression
      timezone_id            = schedule.value.timezone_id
    }
  }

  dynamic "parameter" {
    for_each = contains(keys(each.value), "parameter") ? each.value["parameter"] : []
    content {
      default = parameter.value.default
      name    = parameter.value.name
    }
  }

  dynamic "queue" {
    for_each = { for key, val in each.value : key => val if key == "queue" && val != {} }
    content {
      # enabled = lookup(queue.value, "enabled", null) != null ? queue.value.enabled : false
      enabled = queue.value.enabled
    }
  }

  dynamic "task" {
    for_each = each.value.tasks
    content {
      task_key                  = task.value.task_key
      job_cluster_key           = lookup(task.value, "job_cluster_key", null) != null ? task.value.job_cluster_key : null
      existing_cluster_id       = lookup(task.value, "existing_cluster_id", null) != null ? task.value.existing_cluster_id : null
      max_retries               = contains(keys(task.value), "max_retries") ? task.value["max_retries"] : local.max_retries
      min_retry_interval_millis = contains(keys(task.value), "min_retry_interval_millis") ? task.value["min_retry_interval_millis"] : local.min_retry_interval_millis
      run_if                    = lookup(task.value, "run_if", null) != null ? task.value.run_if : null
      dynamic "notebook_task" {
        for_each = { for key, val in task.value :
        key => val if key == "notebook_task" }
        content {
          notebook_path   = notebook_task.value.notebook_path
          base_parameters = lookup(notebook_task.value, "base_parameters", null) != null ? notebook_task.value.base_parameters : {}
          source          = notebook_task.value.source
        }
      }

      dynamic "depends_on" {
        for_each = contains(keys(task.value), "depends_on") ? task.value["depends_on"] : []

        content {
          task_key = depends_on.value.task_key
          outcome  = lookup(depends_on.value, "outcome", null) != null ? depends_on.value.outcome : null
        }
      }

      dynamic "dbt_task" {
        for_each = { for key, val in task.value :
        key => val if key == "dbt_task" }
        content {
          project_directory = task.value.dbt_task.project_directory
          commands          = task.value.dbt_task.commands
          schema            = task.value.dbt_task.schema
          warehouse_id      = task.value.dbt_task.warehouse_id
          catalog           = task.value.dbt_task.catalog
        }
      }

      dynamic "spark_python_task" {
        for_each = { for key, val in task.value :
        key => val if key == "spark_python_task" }
        content {
          parameters  = task.value.spark_python_task.parameters
          python_file = task.value.spark_python_task.python_file
          source      = task.value.spark_python_task.source
        }
      }

      dynamic "condition_task" {
        for_each = { for key, val in task.value :
        key => val if key == "condition_task" }
        content {
          left  = task.value.condition_task.left
          op    = task.value.condition_task.op
          right = task.value.condition_task.right
        }
      }
      dynamic "library" {
        for_each = contains(keys(task.value), "libraries") ? task.value["libraries"] : []
        content {
          pypi { package = task.value.libraries[0].pypi.package }
        }
      }
      timeout_seconds = lookup(task.value, "timeout_seconds", null) != null ? task.value.timeout_seconds : null

      dynamic "email_notifications" {
        for_each = { for key, val in task.value :
        key => val if key == "email_notifications" && val != {} }
        content {
          on_success = lookup(email_notifications.value, "on_success", null) != null ? email_notifications.value.on_success : null
          on_start   = lookup(email_notifications.value, "on_start", null) != null ? email_notifications.value.on_start : null
          on_failure = lookup(email_notifications.value, "on_failure", null) != null ? email_notifications.value.on_failure : null
        }
      }

      dynamic "notification_settings" {
        for_each = { for key, val in task.value :
        key => val if key == "notification_settings" && val != {} }
        content {
          alert_on_last_attempt      = lookup(notification_settings.value, "alert_on_last_attempt", null) != null ? notification_settings.value.alert_on_last_attempt : false
          no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : false
          no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : false
        }
      }
    }
  }
}

Summary

Considering model retraining and similar operations, Model Serving is deployed via a Notebook Job rather than directly with Terraform. Because everything in the Notebook except the model logic can be shared, it is templated and provided to users. With the template, model authors get an environment where they can deploy a model quickly once it has been built.