Multinomial logistic regression
Import the necessary packages:
using DataFrames
using Distributions
using MLJBase
using NNlib
using RDatasets
using Soss
using SossMLJ
using Statistics
In this example, we fit a Bayesian multinomial logistic regression model with the canonical link function.
Suppose that we are given a matrix of features X and a column vector of labels y. X has n rows and p columns, and y has n elements. We assume that our observation vector y is a realization of a random variable Y. We define μ (mu) as the expected value of Y, i.e. μ := E[Y]. Our model comprises three components:
- The probability distribution of Y. We assume that each Yᵢ follows a multinomial distribution with k categories, mean μᵢ, and one trial.
- The systematic component, which consists of the linear predictor η (eta), which we define as η := Xβ, where β is the matrix of coefficients, with p rows (one per feature) and k columns (one per class).
- The link function g, which provides the relationship g(E[Y]) = g(μ) = η = Xβ. It follows that μ = g⁻¹(η), where g⁻¹ denotes the inverse of g. In multinomial logistic regression, the canonical link function is the generalized logit function, whose inverse is the softmax function. Therefore, when using the canonical link function, μ = g⁻¹(η) = softmax(η), as the short numeric check below illustrates.
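To make the link function concrete, here is a small numeric check (a sketch using the NNlib package imported above, with a made-up linear predictor) that softmax maps any real vector to a valid probability vector:
η_example = [1.0, 2.0, 3.0]          # hypothetical linear predictor for one observation
μ_example = NNlib.softmax(η_example) # μ = g⁻¹(η) = softmax(η)
all(μ_example .> 0)                  # true: every class probability is positive
isapprox(sum(μ_example), 1.0)        # true: the probabilities sum to one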
A multinomial distribution with one trial is equivalent to the categorical distribution. Therefore, the following two statements are equivalent:
- Yᵢ follows a multinomial distribution with k categories, mean μᵢ, and one trial.
- Yᵢ follows a categorical distribution with k categories and mean μᵢ.
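As a quick numeric illustration of this equivalence (a sketch using Distributions.jl with a made-up probability vector), the probability that a one-trial multinomial assigns to a one-hot outcome matches the categorical probability of the corresponding class:
p_example = [0.2, 0.3, 0.5]               # hypothetical class probabilities for k = 3
pdf(Multinomial(1, p_example), [0, 1, 0]) # 0.3: one trial landing in class 2
pdf(Categorical(p_example), 2)            # 0.3: the same event under the categorical distribution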
Observe that the logistic regression model is a special case of the multinomial logistic regression model where k = 2.
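To see this concretely, with k = 2 the softmax of a two-element linear predictor reduces to the ordinary logistic (sigmoid) function. A minimal check, assuming an arbitrary scalar predictor:
η₀ = 0.7                    # an arbitrary scalar linear predictor
NNlib.softmax([0.0, η₀])[2] # probability of the second class under softmax
NNlib.sigmoid(η₀)           # the same value from the logistic function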
In this model, the parameters that we want to estimate are the coefficients β. We need to select prior distributions for these parameters. For each element of β, we choose a normal distribution with zero mean and unit variance.
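For intuition about this prior, a single draw of the coefficient matrix can be simulated directly with Distributions.jl; the dimensions below (p = 4 features, k = 3 classes) are chosen to match the Iris example that follows:
β_draw = rand(Normal(0.0, 1.0), 4, 3) # one prior draw: a 4×3 matrix of coefficients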
We define this model using the Soss probabilistic programming library:
m = @model X,pool begin
    n = size(X, 1) # number of observations
    p = size(X, 2) # number of features
    k = length(pool.levels) # number of classes
    β ~ Normal(0.0, 1.0) |> iid(p, k) # coefficients
    η = X * β # linear predictor
    μ = NNlib.softmax(η; dims=2) # μ = g⁻¹(η) = softmax(η)
    y_dists = UnivariateFinite(pool.levels, μ; pool=pool) # `UnivariateFinite` is mathematically equivalent to `Categorical`
    y ~ For(j -> y_dists[j], n) # `Yᵢ ~ UnivariateFinite(mean=μᵢ, categories=k)`, which is mathematically equivalent to `Yᵢ ~ Categorical(mean=μᵢ, categories=k)`
end;
Import the Iris flower data set:
iris = dataset("datasets", "iris");
Define our feature columns:
feature_columns = [
    :PetalLength,
    :PetalWidth,
    :SepalLength,
    :SepalWidth,
]

4-element Vector{Symbol}:
 :PetalLength
 :PetalWidth
 :SepalLength
 :SepalWidth
Define our label column:
label_column = :Species
:Species
Convert the Soss model into a SossMLJModel:
model = SossMLJModel(;
    model = m,
    predictor = MLJBase.UnivariateFinite,
    hyperparams = (pool=iris.Species.pool,),
    infer = dynamicHMC,
    response = :y,
);
Create an MLJ machine for fitting our model:
mach = MLJBase.machine(model, iris[!, feature_columns], iris[!, :Species])
Machine{SossMLJModel{,…}} @679 trained 0 times.
  args:
    1:  Source @270 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:  Source @882 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`
Fit the machine. This may take several minutes.
MLJBase.fit!(mach)
Machine{SossMLJModel{,…}} @679 trained 1 time.
  args:
    1:  Source @270 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:  Source @882 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`
Construct the joint posterior:
predictor_joint = MLJBase.predict_joint(mach, iris[!, feature_columns])
typeof(predictor_joint)
SossMLJ.SossMLJPredictor{SossMLJModel{MLJBase.UnivariateFinite, Soss.Model{NamedTuple{(:X, :pool), T} where T<:Tuple, TypeEncoding(begin k = length(pool.levels) p = size(X, 2) β ~ Normal(0.0, 1.0) |> iid(p, k) η = X * β μ = NNlib.softmax(η; dims = 2) y_dists = UnivariateFinite(pool.levels, μ; pool = pool) n = size(X, 1) y ~ For((j->begin y_dists[j] end), n) end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:pool,), Tuple{CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}, typeof(Soss.dynamicHMC), Symbol, typeof(SossMLJ.default_transform)}, Vector{NamedTuple{(:β,), Tuple{Matrix{Float64}}}}, Soss.Model{NamedTuple{(:X, :pool, :β), T} where T<:Tuple, TypeEncoding(begin η = X * β μ = NNlib.softmax(η; dims = 2) y_dists = UnivariateFinite(pool.levels, μ; pool = pool) n = size(X, 1) y ~ For((j->begin y_dists[j] end), n) end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:X, :pool), Tuple{Matrix{Float64}, CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}}
Draw a single sample from the joint posterior:
single_sample = rand(predictor_joint)
150-element Vector{CategoricalArrays.CategoricalValue{String, UInt8}}:
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 ⋮
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "versicolor"
 "virginica"
 "virginica"
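Each such draw is one plausible labeling of all 150 rows under the posterior. One way to use these joint draws (a hedged sketch, not part of the original pipeline) is to draw many label vectors and measure their average per-row agreement with the observed labels:
sample_draws = [rand(predictor_joint) for _ in 1:100]               # 100 posterior predictive draws
mean(mean(draw .== iris[!, label_column]) for draw in sample_draws) # average per-row agreement with the observed labels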
For each row in the dataset, construct the marginal posterior predictive distribution:
predictor_marginal = MLJBase.predict(mach, iris[!, feature_columns])
150-element MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.978, versicolor=>0.022, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.959, versicolor=>0.0412, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.973, versicolor=>0.0272, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.948, versicolor=>0.0518, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.983, versicolor=>0.017, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.981, versicolor=>0.0192, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.979, versicolor=>0.0208, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.971, versicolor=>0.0288, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.939, versicolor=>0.0606, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.948, versicolor=>0.0522, virginica=>0.0002)
 ⋮
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.002, versicolor=>0.296, virginica=>0.702)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.0758, virginica=>0.924)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.055, virginica=>0.945)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0004, versicolor=>0.045, virginica=>0.955)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.161, virginica=>0.838)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.181, virginica=>0.819)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.233, virginica=>0.767)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0006, versicolor=>0.0796, virginica=>0.92)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.001, versicolor=>0.172, virginica=>0.827)
predictor_marginal is a UnivariateFiniteVector:
typeof(predictor_marginal)
MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64} (alias for MLJBase.UnivariateFiniteArray{ScientificTypes.Multiclass{3}, String, UInt8, Float64, 1})
predictor_marginal has one element for each row in the data set:
@show size(predictor_marginal); @show size(iris, 1);
size(predictor_marginal) = (150,)
size(iris, 1) = 150
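As a usage sketch (relying only on the objects defined above), per-class probabilities and point predictions can be extracted from the marginal distributions with pdf and mode:
pdf.(predictor_marginal, "virginica") # probability assigned to "virginica" for each of the 150 rows
mode.(predictor_marginal)             # the most probable class for each row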
Use cross-validation to evaluate the model with respect to the Brier score:
evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=brier_score, operation=MLJBase.predict)
┌─────────────────┬───────────────┬───────────────────────────────────┐
│ _.measure       │ _.measurement │ _.per_fold                        │
├─────────────────┼───────────────┼───────────────────────────────────┤
│ BrierScore @626 │ -0.0959       │ [-0.054, -0.0689, -0.151, -0.109] │
└─────────────────┴───────────────┴───────────────────────────────────┘
_.per_observation = [[[-0.00111, -0.00391, ..., -0.0242], [-0.0198, -0.0397, ..., -0.00632], [-0.349, -0.0158, ..., -0.000177], [-0.029, -0.149, ..., -0.115]]]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]
Use cross-validation to evaluate the model with respect to accuracy:
evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=accuracy, operation=MLJBase.predict_mode)
┌───────────────┬───────────────┬────────────────────────────┐
│ _.measure     │ _.measurement │ _.per_fold                 │
├───────────────┼───────────────┼────────────────────────────┤
│ Accuracy @463 │ 0.96          │ [0.974, 0.921, 1.0, 0.946] │
└───────────────┴───────────────┴────────────────────────────┘
_.per_observation = [missing]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]
The cross-validated accuracy is greater than 90%, which is pretty good!
This page was generated using Literate.jl.