This article is the Day 25 entry of Series 2 of the Databricks - Qiita Advent Calendar 2024.
In this post, I introduce how we operate the Notebook template we use when deploying Model Serving endpoints.
Model Serving endpoints can also be deployed with Terraform. However, once you account for model operations such as retraining, the model version managed in Terraform and the model version actually being served can drift apart. We therefore manage only the Notebook Job in Terraform and perform the Model Serving deployment inside the Notebook.
When deploying Model Serving, everything other than the model-specific logic is shared and maintained as a template. With this template, a model author only has to edit the logic section to deploy a Model Serving endpoint.
Notebook template
We have created a template like the one below. The sections marked "EDITABLE" are the parts the model author modifies: a user copies the template and edits only the model logic inside the "EDITABLE" block. The "pre_hook" and "post_hook" calls surrounding the "EDITABLE" block are shared processing; as described later, they register, update, or delete the model and the Model Serving endpoint based on the parameters passed to the Notebook. Managing the "EDITABLE" part as a separate file would make the template simpler, but we chose this layout so that model execution results can be visualized and checked easily on Databricks.
# Databricks notebook source
# MAGIC %run ./model_serving

# COMMAND ----------

import json

import mlflow

mlflow.set_registry_uri("databricks-uc")

# COMMAND ----------

params_string = dbutils.widgets.get("params")
params = json.loads(params_string)
print(f'params: {params}')

# COMMAND ----------

# MAGIC %python
# MAGIC pre_hook(params)

# COMMAND ----------

################ EDITABLE ################
################   START  ################
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier

model = params['model']
artifact_path = model['artifact_path']
endpoint = params['endpoint']
model_name = endpoint['config']['served_entities'][0]['entity_name']
experiment_path = model['experiment_path']

mlflow.set_experiment(experiment_path)

with mlflow.start_run():
    # Train a sklearn model on the iris dataset
    X, y = datasets.load_iris(return_X_y=True, as_frame=True)
    clf = RandomForestClassifier(max_depth=7)
    clf.fit(X, y)

    # Take the first row of the training dataset as the model input example.
    input_example = X.iloc[[0]]

    # Log the model and register it as a new version in UC.
    mlflow.sklearn.log_model(
        sk_model=clf,
        artifact_path=artifact_path,
        # The signature is automatically inferred from the input example and its predicted output.
        input_example=input_example,
        registered_model_name=model_name,
    )
################    END   ################

# COMMAND ----------

# MAGIC %python
# MAGIC post_hook(params)
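For reference, the "params" widget that the template reads is a single JSON string passed in as a job parameter (see the job definition later in this article). The sketch below shows a hypothetical payload and how the notebook turns it into a dict; the catalog and schema names are placeholders, not the actual values.

import json

# Hypothetical example of the "params" job parameter as the notebook receives it.
params_string = '''
{
  "force_delete": "false",
  "model": {
    "experiment_path": "/Shared/common_model_experiments/sample",
    "artifact_path": "model"
  },
  "endpoint": {
    "endpoint_name": "workspace-model-endpoint",
    "config": {
      "served_entities": [
        {
          "name": "iris_model_serving",
          "entity_name": "dev_data_science.sample_project.sample_model",
          "entity_version": "latest",
          "workload_size": "Small",
          "scale_to_zero_enabled": "true"
        }
      ]
    }
  }
}
'''
params = json.loads(params_string)
print(params["endpoint"]["config"]["served_entities"][0]["entity_name"])
# -> dev_data_science.sample_project.sample_model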
Next, here are the "pre_hook" and "post_hook" referenced by the template. As the code below shows, they create, update, or delete the model and the Model Serving endpoint according to the Notebook parameters. Reading the "pre_hook" and "post_hook" logic first should give you a feel for the overall flow.
# Databricks notebook source
from mlflow.deployments import get_deploy_client
from mlflow.tracking import MlflowClient
import time
from distutils.util import strtobool

# COMMAND ----------

def pre_hook(params):
    model = params['model']
    force_delete = bool(strtobool(params['force_delete']))
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'config: {config}')
    print(f'force_delete: {force_delete}')

    if model_registered_exists(model_name) and force_delete:
        delete_registered_model(model_name)
    print('skip delete_registered_model')

    if endpoint_exists(endpoint_name) and force_delete:
        delete_model_serving_endpoint(endpoint_name)
        dbutils.notebook.exit("Model serving endpoint deleted. Exiting the notebook.")
    print('skip delete_model_serving_endpoint')

    if model_registered_exists(model_name) and endpoint_exists(endpoint_name) and model_version != 'latest':
        update_model_serving_endpoint(endpoint_name, config)
        wait_for_endpoint_ready(endpoint_name)
        dbutils.notebook.exit(f'Updated model endpoint version: {model_version}. Exiting the notebook')
    print('skip update_model_serving_endpoint')


def post_hook(params):
    endpoint = params['endpoint']
    endpoint_name = endpoint['endpoint_name']
    config = endpoint['config']
    model_name = config['served_entities'][0]['entity_name']
    model_version = config['served_entities'][0]['entity_version']
    print(f'config: {config}')

    if model_version == 'latest':
        model_version = get_model_registered_latest_version(model_name)
        config['served_entities'][0]['entity_version'] = model_version
    print(f'model version: {model_version}')

    # Check if the endpoint exists
    if endpoint_exists(endpoint_name):
        print(f"Endpoint '{endpoint_name}' exists.")
        wait_for_endpoint_ready(endpoint_name)
        print('update_model_serving_endpoint')
        update_model_serving_endpoint(endpoint_name, config)
    else:
        print(f"Endpoint '{endpoint_name}' does not exist.")
        print('create_model_serving_endpoint')
        create_model_serving_endpoint(endpoint_name, config)

    wait_for_endpoint_ready(endpoint_name)
    print('end')


# Model Registered
def delete_registered_model(model_name):
    try:
        client = MlflowClient()
        client.delete_registered_model(name=model_name)
        print(f"Model '{model_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model '{model_name}': {e}")
        raise


def model_registered_exists(model_name):
    try:
        client = MlflowClient()
        registered_models = client.search_registered_models()
        return any(model.name == model_name for model in registered_models)
    except Exception as e:
        print(f"Error checking if model '{model_name}' is registered in Unity Catalog: {e}")
        return False


def get_model_registered_latest_version(model_name):
    try:
        client = MlflowClient()
        print(f'model_name:{model_name}')
        model_version_infos = client.search_model_versions("name = '%s'" % model_name)
        print(f'model_version_infos:{model_version_infos}')
        latest_version = max([int(model_version_info.version) for model_version_info in model_version_infos])
        return latest_version
    except Exception as e:
        print(f"Error get_model_registered_latest_version: {e}")
        raise


# Model serving
def delete_model_serving_endpoint(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.delete_endpoint(endpoint=endpoint_name)
        print(f"Model serving endpoint '{endpoint_name}' has been deleted.")
    except Exception as e:
        print(f"Error deleting model serving endpoint '{endpoint_name}': {e}")
        raise


def create_model_serving_endpoint(name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.create_endpoint(name=name, config=config)
        print(f"Model serving endpoint '{name}' has been created. config: {config}")
    except Exception as e:
        print(f"Error creating model serving endpoint '{name}': {e}")
        raise


def update_model_serving_endpoint(endpoint_name, config):
    try:
        deploy_client = get_deploy_client("databricks")
        deploy_client.update_endpoint(endpoint=endpoint_name, config=config)
        print(f"Model serving endpoint '{endpoint_name}' has been updated. config: {config}")
    except Exception as e:
        print(f"Error updating model serving endpoint '{endpoint_name}': {e}")
        raise


def endpoint_exists(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoints = deploy_client.list_endpoints()
        return any(endpoint['name'] == endpoint_name for endpoint in endpoints)
    except Exception as e:
        print(f"Error checking if endpoint exists '{endpoint_name}': {e}")
        return False


def get_endpoint_status(endpoint_name):
    try:
        deploy_client = get_deploy_client("databricks")
        endpoint_status = deploy_client.get_endpoint(endpoint=endpoint_name)
        return endpoint_status
    except Exception as e:
        print(f"Error getting status for endpoint '{endpoint_name}': {e}")
        raise


def wait_for_endpoint_ready(endpoint_name, timeout=1000, interval=30):
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # https://docs.databricks.com/api/workspace/servingendpoints/get
            status = get_endpoint_status(endpoint_name)
            if status['state']['ready'] == 'READY':
                print(f"Endpoint '{endpoint_name}' is ready.")
                if status['state']['config_update'] == 'NOT_UPDATING':
                    print(f"Endpoint '{endpoint_name}' is currently NOT_UPDATING.")
                    return
                else:
                    print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
        except Exception as e:
            if "RESOURCE_CONFLICT" in str(e):
                print(f"Endpoint '{endpoint_name}' is currently being updated. Waiting...")
            else:
                raise
        time.sleep(interval)
    raise TimeoutError(f"Endpoint '{endpoint_name}' is not ready after {timeout} seconds.")
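Once post_hook has finished, the endpoint can be queried for a quick smoke test. This is not part of the template; it is a minimal sketch using the same MLflow deployments client, assuming it runs in a Databricks notebook (or with DATABRICKS_HOST and DATABRICKS_TOKEN configured) and that the iris model above is served on "workspace-model-endpoint".

from mlflow.deployments import get_deploy_client

# Minimal smoke test against the endpoint deployed by the template (illustrative only).
client = get_deploy_client("databricks")
response = client.predict(
    endpoint="workspace-model-endpoint",
    inputs={
        "dataframe_split": {
            "columns": [
                "sepal length (cm)", "sepal width (cm)",
                "petal length (cm)", "petal width (cm)",
            ],
            "data": [[5.1, 3.5, 1.4, 0.2]],
        }
    },
)
print(response)  # predictions returned by the serving endpoint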
Databricks Job (Terraform)
The Notebook Job is deployed with Terraform, and its parameters are defined in the YAML file below. Terraform deploys the Notebook Job dynamically based on this YAML configuration.
jobs:
  wf_test_model_endpoint_tmpl:
    name: wf_test_model_endpoint_tmpl
    domain_tag: test
    notebook_path: "notebook/model_serving/model_endpoint.tmpl"
    job_params: {
      "force_delete": "false",
      "model": {
        "experiment_path": "/Shared/common_model_experiments/sample",
        "artifact_path": "model",
      },
      "endpoint": {
        "endpoint_name": "workspace-model-endpoint",
        "config": {
          "served_entities": [
            {
              "name": "iris_model_serving",
              "entity_name": "${Env}_data_science.<project>.sample_model",
              "entity_version": "latest", # latest or version number for rollback
              "workload_size": "Small",
              "scale_to_zero_enabled": "true",
            },
          ],
          "auto_capture_config": {
            "catalog_name": "${Env}_catalog",
            "schema_name": "sample_schema",
            "enabled": "true"
          }
        },
      },
    }
    clusters: {
      spark_version: "15.3.x-cpu-ml-scala2.12",
      node_type_id: "i3.2xlarge",
      driver_node_type_id: "i3.2xlarge",
      autoscale: { min_workers: 2, max_workers: 10 },
    }
    template_path: "../jobs/template/model_serving_template.json"
    git_url: "https://git-codecommit.ap-northeast-1.amazonaws.com/v1/repos/<repository_name>"
The Notebook Job template we created for Model Serving is shown below. It uses the parameters defined in the YAML to build the Job definition dynamically.
{
  "name": "${name}",
  "email_notifications": {
    "no_alert_for_skipped_runs": false
  },
  "notification_settings": {
    "no_alert_for_canceled_runs": true
  },
  "webhook_notifications": {},
  "timeout_seconds": 0,
  "max_concurrent_runs": 1,
  "tags": {
    "product": "${domain_tag}"
  },
  "parameters": [
    {
      "name": "params",
      "default": "${params}"
    }
  ],
  "job_clusters": [
    {
      "job_cluster_key": "job_cluster_key_${env}",
      "new_cluster": {
        "spark_version": "${spark_version}",
        "node_type_id": "${node_type_id}",
        "driver_node_type_id": "${driver_node_type_id}",
        "policy_id": "${policy_id}",
        "autoscale": {
          "min_workers": "${min_workers}",
          "max_workers": "${max_workers}"
        },
        "aws_attributes": {
          "first_on_demand": "${first_on_demand}"
        }
      }
    }
  ],
  "tasks": [
    {
      "task_key": "deploy_model_serving",
      "max_retries": 0,
      "notebook_task": {
        "notebook_path": "${notebook_path}",
        "source": "GIT"
      },
      "job_cluster_key": "job_cluster_key_${env}",
      "libraries": [
        {
          "pypi": {
            "package": "mlflow-skinny[databricks]>=2.5.0"
          }
        }
      ]
    }
  ],
  "git_source": {
    "git_url": "${git_url}",
    "git_provider": "awscodecommit",
    "git_tag": "t_${env}"
  },
  "format": "MULTI_TASK"
}
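To make the flow from YAML to Job definition concrete, here is a purely local sketch in Python. In the actual setup the substitution is presumably handled on the Terraform side (for example with templatefile), so treat this only as an illustration of how job_params ends up as the single "params" job parameter that the Notebook later parses with json.loads.

import json
from string import Template

# Values taken from the YAML above (trimmed); the whole job_params dict is serialized
# into one JSON string and injected into the ${params} placeholder of the template.
job_params = {
    "force_delete": "false",
    "model": {
        "experiment_path": "/Shared/common_model_experiments/sample",
        "artifact_path": "model",
    },
}
snippet = Template('{"name": "${name}", "parameters": [{"name": "params", "default": "${params}"}]}')
rendered = snippet.safe_substitute(
    name="wf_test_model_endpoint_tmpl",
    params=json.dumps(job_params).replace('"', '\\"'),  # escape so it stays one JSON string value
)
print(json.loads(rendered)["parameters"][0]["default"])
# prints the JSON string that the Notebook later parses with json.loads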
The dynamically built Job information is passed to the Job defined in the Terraform below. This mechanism shares a common definition with job resources other than Model Serving, which keeps the management efficient. Per-job differences are absorbed in locals, and local.job-association-map is used to receive the information for all job resources in one place.
resource "databricks_job" "job" {
  for_each = local.job-association-map
  depends_on = [
    databricks_cluster.shared_dbx_cluster
  ]
  name                = each.value.name
  timeout_seconds     = each.value.timeout_seconds
  max_concurrent_runs = each.value.max_concurrent_runs

  git_source {
    url      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_url : null
    provider = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_provider : null
    tag      = lookup(each.value, "git_source", null) != null ? each.value.git_source.git_tag : null
  }

  tags = {
    product = lookup(each.value, "tags", null) != null ? each.value.tags.product : local.tags.product
  }

  email_notifications {
    no_alert_for_skipped_runs = lookup(each.value.email_notifications, "no_alert_for_skipped_runs", null) != null ? each.value.email_notifications.no_alert_for_skipped_runs : null
    on_start                  = lookup(each.value.email_notifications, "on_start", []) != [] ? each.value.email_notifications.on_start : []
    on_success                = lookup(each.value.email_notifications, "on_success", []) != [] ? each.value.email_notifications.on_success : []
    on_failure                = lookup(each.value.email_notifications, "on_failure", []) != [] ? each.value.email_notifications.on_failure : local.on_failure
  }

  format = each.value.format

  dynamic "trigger" {
    for_each = { for key, val in each.value : key => val if key == "trigger" && val != null }
    content {
      pause_status = trigger.value.pause_status
      file_arrival {
        url = trigger.value.file_arrival.url
      }
    }
  }

  # use existing cluster instead of new_cluster. This will be used for IDBCDB, to import existing resources.
  dynamic "job_cluster" {
    for_each = { for key, val in each.value.job_clusters : key => val if lookup(val, "new_cluster", null) != null }
    content {
      job_cluster_key = each.value.job_clusters[0].job_cluster_key
      new_cluster {
        spark_version       = lookup(each.value.job_clusters[0].new_cluster, "spark_version", null) != null ? each.value.job_clusters[0].new_cluster.spark_version : local.clusters.spark_version
        node_type_id        = lookup(each.value.job_clusters[0].new_cluster, "node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.node_type_id : local.clusters.node_type_id
        driver_node_type_id = lookup(each.value.job_clusters[0].new_cluster, "driver_node_type_id", null) != null ? each.value.job_clusters[0].new_cluster.driver_node_type_id : local.clusters.driver_node_type_id
        policy_id           = lookup(each.value.job_clusters[0].new_cluster, "policy_id", null) != null ? each.value.job_clusters[0].new_cluster.policy_id : local.clusters.policy_id
        runtime_engine      = lookup(each.value.job_clusters[0].new_cluster, "runtime_engine", null) != null ? each.value.job_clusters[0].new_cluster.runtime_engine : local.clusters.runtime_engine
        spark_conf          = lookup(each.value.job_clusters[0].new_cluster, "spark_conf", null) != null ? each.value.job_clusters[0].new_cluster.spark_conf : null
        autoscale {
          min_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "min_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.min_workers : local.clusters.autoscale.min_workers
          max_workers = lookup(each.value.job_clusters[0].new_cluster.autoscale, "max_workers", null) != null ? each.value.job_clusters[0].new_cluster.autoscale.max_workers : local.clusters.autoscale.max_workers
        }
        aws_attributes {
          first_on_demand        = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "first_on_demand", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.first_on_demand : local.clusters.aws_attributes.first_on_demand
          availability           = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "availability", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.availability : local.clusters.aws_attributes.availability
          instance_profile_arn   = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "instance_profile_arn", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.instance_profile_arn : local.clusters.aws_attributes.instance_profile_arn
          zone_id                = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "zone_id", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.zone_id : local.clusters.aws_attributes.zone_id
          spot_bid_price_percent = lookup(each.value.job_clusters[0].new_cluster.aws_attributes, "spot_bid_price_percent", null) != null ? each.value.job_clusters[0].new_cluster.aws_attributes.spot_bid_price_percent : local.clusters.aws_attributes.spot_bid_price_percent
        }
        custom_tags = each.value.job_clusters[0].new_cluster.custom_tags
      }
    }
  }

  dynamic "notification_settings" {
    for_each = { for key, val in each.value : key => val if key == "notification_settings" && val != null }
    content {
      no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : null
      no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : null
    }
  }

  dynamic "schedule" {
    for_each = { for key, val in each.value : key => val if key == "schedule" && val != null }
    content {
      pause_status           = schedule.value.pause_status
      quartz_cron_expression = schedule.value.quartz_cron_expression
      timezone_id            = schedule.value.timezone_id
    }
  }

  dynamic "parameter" {
    for_each = contains(keys(each.value), "parameter") ? each.value["parameter"] : []
    content {
      default = parameter.value.default
      name    = parameter.value.name
    }
  }

  dynamic "queue" {
    for_each = { for key, val in each.value : key => val if key == "queue" && val != {} }
    content {
      # enabled = lookup(queue.value, "enabled", null) != null ? queue.value.enabled : false
      enabled = queue.value.enabled
    }
  }

  dynamic "task" {
    for_each = each.value.tasks
    content {
      task_key                  = task.value.task_key
      job_cluster_key           = lookup(task.value, "job_cluster_key", null) != null ? task.value.job_cluster_key : null
      existing_cluster_id       = lookup(task.value, "existing_cluster_id", null) != null ? task.value.existing_cluster_id : null
      max_retries               = contains(keys(task.value), "max_retries") ? task.value["max_retries"] : local.max_retries
      min_retry_interval_millis = contains(keys(task.value), "min_retry_interval_millis") ? task.value["min_retry_interval_millis"] : local.min_retry_interval_millis
      run_if                    = lookup(task.value, "run_if", null) != null ? task.value.run_if : null

      dynamic "notebook_task" {
        for_each = { for key, val in task.value : key => val if key == "notebook_task" }
        content {
          notebook_path   = notebook_task.value.notebook_path
          base_parameters = lookup(notebook_task.value, "base_parameters", null) != null ? notebook_task.value.base_parameters : {}
          source          = notebook_task.value.source
        }
      }

      dynamic "depends_on" {
        for_each = contains(keys(task.value), "depends_on") ? task.value["depends_on"] : []
        content {
          task_key = depends_on.value.task_key
          outcome  = lookup(depends_on.value, "outcome", null) != null ? depends_on.value.outcome : null
        }
      }

      dynamic "dbt_task" {
        for_each = { for key, val in task.value : key => val if key == "dbt_task" }
        content {
          project_directory = task.value.dbt_task.project_directory
          commands          = task.value.dbt_task.commands
          schema            = task.value.dbt_task.schema
          warehouse_id      = task.value.dbt_task.warehouse_id
          catalog           = task.value.dbt_task.catalog
        }
      }

      dynamic "spark_python_task" {
        for_each = { for key, val in task.value : key => val if key == "spark_python_task" }
        content {
          parameters  = task.value.spark_python_task.parameters
          python_file = task.value.spark_python_task.python_file
          source      = task.value.spark_python_task.source
        }
      }

      dynamic "condition_task" {
        for_each = { for key, val in task.value : key => val if key == "condition_task" }
        content {
          left  = task.value.condition_task.left
          op    = task.value.condition_task.op
          right = task.value.condition_task.right
        }
      }

      dynamic "library" {
        for_each = contains(keys(task.value), "libraries") ? task.value["libraries"] : []
        content {
          pypi {
            package = task.value.libraries[0].pypi.package
          }
        }
      }

      timeout_seconds = lookup(task.value, "timeout_seconds", null) != null ? task.value.timeout_seconds : null

      dynamic "email_notifications" {
        for_each = { for key, val in task.value : key => val if key == "email_notifications" && val != {} }
        content {
          on_success = lookup(email_notifications.value, "on_success", null) != null ? email_notifications.value.on_success : null
          on_start   = lookup(email_notifications.value, "on_start", null) != null ? email_notifications.value.on_start : null
          on_failure = lookup(email_notifications.value, "on_failure", null) != null ? email_notifications.value.on_failure : null
        }
      }

      dynamic "notification_settings" {
        for_each = { for key, val in task.value : key => val if key == "notification_settings" && val != {} }
        content {
          alert_on_last_attempt      = lookup(notification_settings.value, "alert_on_last_attempt", null) != null ? notification_settings.value.alert_on_last_attempt : false
          no_alert_for_canceled_runs = lookup(notification_settings.value, "no_alert_for_canceled_runs", null) != null ? notification_settings.value.no_alert_for_canceled_runs : false
          no_alert_for_skipped_runs  = lookup(notification_settings.value, "no_alert_for_skipped_runs", null) != null ? notification_settings.value.no_alert_for_skipped_runs : false
        }
      }
    }
  }
}
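Since Terraform manages the Job itself, day-two operations such as a rollback can be done by re-running the Job with overridden parameters. Below is a hedged sketch using the databricks-sdk package; the job ID, catalog, and schema names are hypothetical, and the payload mirrors the job_params structure from the YAML with entity_version pinned instead of "latest", which makes pre_hook update the endpoint and exit before the EDITABLE cell runs.

import json

from databricks.sdk import WorkspaceClient  # assumes databricks-sdk is installed and credentials are configured

# Hypothetical rollback: re-run the Notebook Job with a pinned model version.
params = {
    "force_delete": "false",
    "model": {
        "experiment_path": "/Shared/common_model_experiments/sample",
        "artifact_path": "model",
    },
    "endpoint": {
        "endpoint_name": "workspace-model-endpoint",
        "config": {
            "served_entities": [
                {
                    "name": "iris_model_serving",
                    "entity_name": "dev_data_science.sample_project.sample_model",  # hypothetical UC model
                    "entity_version": "3",  # pinned version -> pre_hook updates the endpoint and exits
                    "workload_size": "Small",
                    "scale_to_zero_enabled": "true",
                }
            ],
        },
    },
}

w = WorkspaceClient()
run = w.jobs.run_now(
    job_id=123456789,  # hypothetical job ID of wf_test_model_endpoint_tmpl
    job_parameters={"params": json.dumps(params)},
).result()  # blocks until the notebook run finishes
print(run.state.result_state)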
Summary
For Model Serving deployments, with model retraining in mind, we rely on a Notebook Job instead of deploying directly from Terraform. Since everything in the Notebook except the logic section can be shared, we provide it to users as a template. With this template, model authors get an environment where they can deploy quickly once a model has been built.