Model Explainability
During the training process, machine learning models infer relationships between inputs and outputs, rather than requiring that these relationships be stated explicitly up front. This lets ML techniques tackle complicated scenarios involving many variables without extensive setup, particularly where the causal factors behind an outcome are complex or unclear. The drawback is that the resulting model can be something of a black box. If a model underperforms, it can be difficult to understand why, and therefore how to improve it. A black-box model can also conceal implicit biases and give no clear reasons for its decisions. Industries with regulations around trustworthy systems, such as finance and healthcare, might require stronger evidence that the model is producing the correct results for the right reasons.
To help address such concerns, the Snowflake Model Registry includes an explainability function based on Shapley values. Shapley values, which originate in cooperative game theory, attribute the output of a machine learning model to its input features. By considering all possible combinations of features, Shapley values measure the average marginal contribution of each feature to the model's prediction. This approach distributes credit fairly in a precise sense. For example, the contributions of all features add up to the difference between the model's prediction and a baseline prediction. The number of feature combinations grows exponentially with the number of features, so Shapley values are computationally intensive to calculate. Even so, they provide valuable insight for model interpretability and debugging.
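The following minimal, from-scratch sketch computes exact Shapley values for a model with only a few features. It makes the "all possible combinations" idea concrete. It is not the registry's implementation, which uses the SHAP library. The function name `shapley_values` is hypothetical. The convention of filling absent features with a fixed baseline value is also an assumption made for illustration. Enumerating every coalition is only practical for a handful of features, which is why real explainers rely on faster algorithms or approximations.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Dict, Tuple


def shapley_values(
    predict: Callable[[Dict[str, float]], float],
    instance: Dict[str, float],
    baseline: Dict[str, float],
) -> Dict[str, float]:
    """Exact Shapley values by enumerating every coalition of features.

    Assumption for this sketch: a feature outside the coalition takes its
    baseline value, so v(S) = predict(instance on S, baseline elsewhere).
    """
    features = list(instance)
    n = len(features)

    def value(coalition: Tuple[str, ...]) -> float:
        mixed = {f: instance[f] if f in coalition else baseline[f] for f in features}
        return predict(mixed)

    phi: Dict[str, float] = {}
    for feature in features:
        others = [f for f in features if f != feature]
        total = 0.0
        for size in range(n):
            # Weight |S|! (n - |S| - 1)! / n! for coalitions S that exclude `feature`
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                # Marginal contribution of adding `feature` to coalition S
                total += weight * (value(subset + (feature,)) - value(subset))
        phi[feature] = total
    return phi


if __name__ == "__main__":
    # Toy model with an interaction term between a and b
    def toy_model(x: Dict[str, float]) -> float:
        return 2 * x["a"] + 3 * x["b"] + x["a"] * x["b"]

    phi = shapley_values(toy_model, {"a": 1.0, "b": 1.0}, {"a": 0.0, "b": 0.0})
    print(phi)  # {'a': 2.5, 'b': 3.5}
    # Efficiency: contributions sum to predict(instance) - predict(baseline) = 6.0
    print(sum(phi.values()))  # 6.0
```

In the toy model, the interaction term `a * b` contributes 1 to the prediction and is split evenly between `a` and `b`. This is the averaging of marginal contributions described above.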
For example, if a house pricing model has been trained on size, location, number of bedrooms, and whether pets are allowed, the model might predict a price of $250,000 for a house that is 2000 square feet, is on the beach, has three bedrooms, and does not allow pets. Each of these feature values might contribute to the final model prediction as shown in the following table.
| Feature | Value | Contribution |
|---|---|---|
| Size | 2000 sq ft | $100,000 |
| Location | Beachside | $100,000 |
| Bedrooms | 3 | $100,000 |
| Pets | No | -$50,000 |
As the Pets row illustrates, contributions can be negative. In this example, it is less desirable to live in a house where pets are not allowed, so that feature value's contribution is -$50,000.
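Shapley contributions are additive: a baseline prediction plus the sum of the per-feature contributions reproduces the model's prediction. In the house-pricing table, the contributions alone add up to $250,000, so the example implicitly uses a baseline of zero. With the SHAP library, the baseline is typically the model's expected prediction over some background data. The short sketch below checks that arithmetic. It also ranks features by the magnitude of their contributions, which is a common way to read an explanation. The helper names are hypothetical.

```python
from typing import Dict, List, Tuple


def reconstruct_prediction(baseline: float, contributions: Dict[str, float]) -> float:
    """Baseline prediction plus the sum of all feature contributions."""
    return baseline + sum(contributions.values())


def rank_by_impact(contributions: Dict[str, float]) -> List[Tuple[str, float]]:
    """Features ordered by absolute contribution, largest first (ties keep input order)."""
    return sorted(contributions.items(), key=lambda item: abs(item[1]), reverse=True)


if __name__ == "__main__":
    # Contributions from the illustrative house-pricing table
    house = {"Size": 100_000, "Location": 100_000, "Bedrooms": 100_000, "Pets": -50_000}
    print(reconstruct_prediction(0, house))  # 250000
    print(rank_by_impact(house)[-1])  # ('Pets', -50000)
```

Ranking by absolute value treats a large negative contribution as just as informative as a large positive one. The sign still indicates the direction in which the feature pushed the prediction.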
This preview release supports the following Python-native model packages.
XGBoost
CatBoost
LightGBM
Snowpark ML modeling classes from snowflake.ml.modeling are not supported in this release.
Explainability is available by default for models logged using Snowpark ML 1.6.0 and later. The implementation uses the SHAP library.
Retrieving Explainability Values
Models with explainability have a method named explain that returns the Shapley values for the model's features. Shapley values explain predictions made from specific inputs, so you must pass input data to explain. The method uses that data to generate the predictions to be explained.
In Python, call explain through the model version's ModelVersion.run method, passing "explain" as the function name:
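```python
from snowflake.ml.registry import Registry
reg = Registry(...)  # connect to the registry using your Snowpark session
mv = reg.get_model("Explainable_Catboost_Model").default
# input_data holds the rows whose predictions you want explained
explanations = mv.run(input_data, function_name="explain")
```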
The following is an example of retrieving the explanation in SQL.
```sql
WITH MV_ALIAS AS MODEL DATABASE.SCHEMA.DIAMOND_CATBOOST_MODEL VERSION EXPLAIN_V0
SELECT *
FROM DATABASE.SCHEMA.DIAMOND_DATA,
    TABLE(MV_ALIAS!EXPLAIN(CUT, COLOR, CLARITY, CARAT, DEPTH, TABLE_PCT, X, Y, Z));
```
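Each row of explain output explains a single prediction. To get a model-wide view, a common next step is to average the absolute Shapley value of each feature across rows (often called mean |SHAP|) and rank the features. The sketch below does this in plain Python on rows that have already been fetched from the Python or SQL call. It assumes, hypothetically, that each row is a mapping from feature name to that feature's Shapley value. The real output may contain other columns as well, and its column names are not covered here. The function name `global_importance` is also hypothetical.

```python
from typing import Dict, Iterable, List, Tuple


def global_importance(rows: Iterable[Dict[str, float]]) -> List[Tuple[str, float]]:
    """Mean absolute Shapley value per feature, sorted from most to least important.

    Each row is assumed to map feature name -> Shapley value for one explained
    prediction; a feature missing from a row counts as zero for that row.
    """
    totals: Dict[str, float] = {}
    count = 0
    for row in rows:
        count += 1
        for feature, value in row.items():
            totals[feature] = totals.get(feature, 0.0) + abs(value)
    if count == 0:
        return []
    means = {feature: total / count for feature, total in totals.items()}
    return sorted(means.items(), key=lambda item: item[1], reverse=True)


if __name__ == "__main__":
    # Two illustrative explanation rows (made-up numbers, not real output)
    rows = [
        {"x1": 1200.0, "x2": -150.0, "x3": 300.0},
        {"x1": -800.0, "x2": 50.0, "x3": 100.0},
    ]
    print(global_importance(rows))
    # [('x1', 1000.0), ('x3', 200.0), ('x2', 100.0)]
```

Taking absolute values before averaging prevents positive and negative contributions from cancelling out. Otherwise a feature that strongly pushes predictions in both directions could look unimportant.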
Adding Explainability to Existing Models
Models that were logged in the registry using a version of Snowpark ML older than 1.6.0 do not have the explainability feature. Model versions are immutable, so you must create a new model version to add explainability to an existing model. Use ModelVersion.load to retrieve the Python object representing the model's implementation, then log that object to the registry as a new model version, as shown below.
Important
The Python environment into which you load the model must be exactly the same (that is, the same version of Python and of all libraries) as the environment where the model is deployed. For details, see Loading a Model Version.
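```python
# Load the implementation of a model version logged without explainability
mv_old = reg.get_model("model_without_explain_enabled").default
model = mv_old.load()

# Log it as a new model version; explainability is enabled by default
mv_new = reg.log_model(
    model,
    model_name="model_with_explain_enabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,  # xs: a sample of the model's input features
)
```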
Logging Models Without Explainability
Explainability is enabled by default. To log a model version in the registry without explainability, pass False for the enable_explainability option when logging the model, as shown here.
```python
mv = reg.log_model(
    catboost_model,
    model_name="diamond_catboost_explain_disabled",
    version_name="explain_v0",
    conda_dependencies=["snowflake-ml-python"],
    sample_input_data=xs,
    options={"enable_explainability": False},  # log without the explain method
)
```