Model Explainability

During training, machine learning models infer relationships between inputs and outputs rather than requiring those relationships to be stated explicitly up front. This allows ML techniques to tackle complicated scenarios involving many variables without extensive setup, particularly where the causal factors behind an outcome are complex or unclear, but the resulting model can be something of a black box. If a model underperforms, it can be difficult to understand why, let alone how to improve its performance. A black-box model can also conceal implicit biases and fail to establish clear reasons for its decisions. Industries with regulations around trustworthy systems, such as finance and healthcare, might require stronger evidence that the model is producing correct results for the right reasons.

To help address such concerns, the Snowflake Model Registry includes an explainability function based on Shapley values. Shapley values are a way to attribute the output of a machine learning model to its input features. By considering all possible combinations of features, Shapley values measure the average marginal contribution of each feature to the model’s prediction. This approach ensures fairness in attributing importance and provides a solid foundation for understanding complex models. While computationally intensive, the insights gained from Shapley values are invaluable for model interpretability and debugging.

For example, if a house pricing model has been trained on size, location, number of bedrooms, and whether pets are allowed, the model might predict a price of $250,000 for a house that is 2000 square feet, is on the beach, has three bedrooms, and does not allow pets. Each of these feature values might contribute to the final model prediction as shown in the following table.

Feature     Value        Contribution
----------  -----------  ------------
Size        2000 sq ft   $100,000
Location    Beachside    $100,000
Bedrooms    3            $100,000
Pets        No           -$50,000

Note that, as illustrated by the Pets line, contributions can be negative. In this example, it is less desirable to live in a house where pets are not allowed, so that feature value’s contribution is -$50,000.
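
Outside the registry, per-feature contributions like those in the table are typically computed with the SHAP library. The following is a minimal local sketch, not a registry workflow: it trains a small XGBoost regressor on made-up housing data and prints the Shapley values for a single prediction. All feature names and data values here are illustrative.

# Minimal local sketch of Shapley values with the SHAP library.
# The feature names and data are illustrative, not a real housing dataset.
import pandas as pd
import shap
import xgboost

X = pd.DataFrame({
    "size_sqft":    [2000, 1500, 1200, 2500],
    "beachside":    [1, 0, 0, 1],
    "bedrooms":     [3, 2, 2, 4],
    "pets_allowed": [0, 1, 1, 0],
})
y = [250_000, 180_000, 150_000, 320_000]

model = xgboost.XGBRegressor(n_estimators=10).fit(X, y)

# TreeExplainer computes exact Shapley values for tree-based models.
explainer = shap.TreeExplainer(model)
contributions = explainer.shap_values(X.iloc[[0]])  # one row: the first house
print(dict(zip(X.columns, contributions[0])))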

This preview release supports the following Python-native model packages.

  • XGBoost

  • CatBoost

  • LightGBM

Snowpark ML modeling classes from snowflake.ml.modeling are not supported in this release.

Explainability is available by default for models logged using Snowpark ML (snowflake-ml-python) 1.6.0 and later. The implementation uses the SHAP library.
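
For example, a supported model such as an XGBoost classifier is logged in the usual way; no extra options are needed for explainability. The following sketch assumes an existing Snowpark session and training data (session, xs, and ys are placeholders).

from snowflake.ml.registry import Registry
import xgboost

# Train a supported Python-native model (xs and ys are placeholder training data).
xgb_model = xgboost.XGBClassifier().fit(xs, ys)

# Log it to the registry; explainability is enabled by default.
reg = Registry(session=session)
mv = reg.log_model(
    xgb_model,
    model_name="xgb_model_with_explain",
    version_name="v1",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,
)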

Retrieving Explainability Values

Models with explainability have a method named explain that returns the Shapley values for the model’s features.

Because Shapley values are explanations of predictions made from specific inputs, you must pass input data to explain to generate the predictions to be explained.

In Python, the Snowflake model version object exposes a method named explain, which you invoke through ModelVersion.run by passing function_name="explain".

from snowflake.ml.registry import Registry

reg = Registry(...)  # connect to the registry in your Snowpark session
mv = reg.get_model("Explainable_Catboost_Model").default
explanations = mv.run(input_data, function_name="explain")
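
The input_data argument can be a pandas DataFrame or a Snowpark DataFrame. For example, the following sketch reads the same diamond data used in the SQL example that follows; the table name is illustrative.

# input_data can be a Snowpark DataFrame or a pandas DataFrame.
# The table name below is illustrative.
input_data = session.table("DATABASE.SCHEMA.DIAMOND_DATA")
explanations = mv.run(input_data, function_name="explain")
explanations.show()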

The following is an example of retrieving the explanation in SQL.

WITH MV_ALIAS AS MODEL DATABASE.SCHEMA.DIAMOND_CATBOOST_MODEL VERSION EXPLAIN_V0
SELECT *
    FROM DATABASE.SCHEMA.DIAMOND_DATA,
        TABLE(MV_ALIAS!EXPLAIN(CUT, COLOR, CLARITY, CARAT, DEPTH, TABLE_PCT, X, Y, Z));

Important

Explainability uses version 0.42.1 of the SHAP library, which is incompatible with the latest XGBoost version (2.1.1) supported by Snowflake. If you receive the error UnicodeDecodeError: 'utf-8' codec can't decode byte, downgrade XGBoost to version 2.0.3 and log the model with the relax_version option set to False, as shown here.

mv_new = reg.log_model(
    model,
    model_name="model_with_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,
    options={"relax_version": False},  # keep dependency versions pinned exactly
)

Adding Explainability to Existing Models

Models that were logged in the registry using a version of Snowpark ML older than 1.6.0 do not have the explainability feature. Since model versions are immutable, you must create a new model version to add explainability to an existing model. You can use ModelVersion.load to retrieve the Python object representing the model’s implementation, then log that object to the registry as a new model version. This approach is shown below.

Important

The Python environment into which you load the model must be exactly the same (that is, the same version of Python and of all libraries) as the environment where the model is deployed. For details, see Loading a model version.

mv_old = reg.get_model("model_without_explain_enabled").default
model = mv_old.load()  # retrieve the underlying Python model object
mv_new = reg.log_model(
    model,
    model_name="model_with_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,
)
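
After logging the new version, you can check that it exposes the explain method. The following sketch assumes that ModelVersion.show_functions is available in your version of snowflake-ml-python.

# List the functions exposed by the new model version; EXPLAIN should appear
# alongside the model's prediction methods. (Assumes show_functions is available.)
for f in mv_new.show_functions():
    print(f["name"])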

Logging Models Without Explainability

Explainability is enabled by default. To log a model version in the registry without explainability, pass False for the enable_explainability option when logging the model, as shown here.

mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,
    options={"enable_explainability": False},
)