Multinomial logistic regression
Import the necessary packages:
using DataFrames
using Distributions
using MLJBase
using NNlib
using RDatasets
using Soss
using SossMLJ
using Statistics
In this example, we fit a Bayesian multinomial logistic regression model with the canonical link function.
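Suppose that we are given a matrix of features $X$ and a column vector of labels $y$. $X$ has $n$ rows and $p$ columns; $y$ has $n$ elements. We assume that our observation vector $y$ is a realization of a random variable $Y$. We define $\mu$ (mu) as the expected value of $Y$, i.e. $\mu := \mathbb{E}[Y]$. Here each label $Y_i$ is encoded as a one-hot vector of length $k$, where $k$ is the number of categories, so $\mu$ is an $n \times k$ matrix. Our model comprises three components: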
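- The probability distribution of $Y$. We assume that each $Y_i$ follows a multinomial distribution with $k$ categories, mean $\mu_i$ (the $i$th row of $\mu$), and one trial.
- The systematic component, which consists of the linear predictor $\eta$ (eta), which we define as $\eta := X\beta$, where $\beta$ is the $p \times k$ matrix of coefficients (one column per category). Consequently, $\eta$ is an $n \times k$ matrix.
- The link function $g$, which provides the following relationship: $g(\mathbb{E}[Y]) = g(\mu) = \eta = X\beta$. It follows that $\mu = g^{-1}(\eta)$, where $g^{-1}$ denotes the inverse of $g$.

In multinomial logistic regression, the canonical link function is the generalized logit function. The inverse of the generalized logit function is the softmax function. Therefore, when using the canonical link function, $\mu = g^{-1}(\eta) = \operatorname{softmax}(\eta)$, where the softmax is applied to each row of $\eta$:

$$\mu_{ic} = \frac{\exp(\eta_{ic})}{\sum_{\ell=1}^{k} \exp(\eta_{i\ell})}, \qquad c = 1, \ldots, k.$$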
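The relationship between the generalized logit and the softmax can be checked numerically. The sketch below uses only Base Julia; `softmax_vec` and `generalized_logit` are hypothetical helpers written for illustration (not the NNlib API), and the values of $\eta$ are made up. With the last category $k$ as the reference, the generalized logit is $g(\mu)_c = \log(\mu_c / \mu_k)$, and applying it to $\operatorname{softmax}(\eta)$ returns $\eta_c - \eta_k$:

```julia
# Softmax of a vector of linear predictors η (one entry per category).
# Subtracting the maximum first avoids overflow and leaves the result unchanged.
function softmax_vec(η::AbstractVector{<:Real})
    z = exp.(η .- maximum(η))
    return z ./ sum(z)
end

# Generalized logit with the last category k as the reference: g(μ)_c = log(μ_c / μ_k)
generalized_logit(μ::AbstractVector{<:Real}) = log.(μ ./ μ[end])

η = [1.5, -0.4, 0.2]     # hypothetical linear predictors for one observation, k = 3
μ = softmax_vec(η)       # category probabilities for that observation

@assert all(μ .> 0) && sum(μ) ≈ 1            # μ is a valid probability vector
@assert generalized_logit(μ) ≈ η .- η[end]    # g(softmax(η)) recovers η up to a shared shift
@assert softmax_vec(η .+ 10.0) ≈ μ            # adding a constant to every entry changes nothing
```

Because the softmax ignores a constant added to every category, the likelihood of this model does not change if the same vector is added to every column of the $p \times k$ coefficient matrix; in the model below, only the prior on $\beta$ pins down that direction.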
A multinomial distribution with one trial is equivalent to the categorical distribution. Therefore, the following two statements are equivalent:
- $Y_i$ follows a multinomial distribution with $k$ categories, mean $\mu_i$, and one trial.
- $Y_i$ follows a categorical distribution with $k$ categories and category probabilities $\mu_i$.
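The equivalence can be verified with a hand-written multinomial probability mass function. In this Base-only sketch, `multinomial_pmf` is a hypothetical helper and the probabilities are made up. With one trial, the only possible outcomes are the one-hot vectors $e_c$, and the multinomial assigns each of them exactly the probability that the categorical distribution assigns to category $c$:

```julia
# Multinomial probability mass function: P(x) = m! / (x₁! ⋯ x_k!) · ∏ p_c^{x_c}, with m = Σ x_c trials
multinomial_pmf(x::AbstractVector{<:Integer}, p::AbstractVector{<:Real}) =
    factorial(sum(x)) / prod(factorial.(x)) * prod(p .^ x)

p = [0.2, 0.5, 0.3]      # hypothetical category probabilities μᵢ for k = 3
k = length(p)

for c in 1:k
    e_c = [j == c ? 1 : 0 for j in 1:k]   # one-hot vector: an outcome of a single trial
    # Multinomial(1, p) puts mass p_c on e_c, exactly what Categorical(p) puts on category c
    @assert multinomial_pmf(e_c, p) ≈ p[c]
end
```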
Observe that the logistic regression model is a special case of the multinomial logistic regression model where $k = 2$.
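With $k = 2$, the softmax probability of the second category equals the logistic (sigmoid) function applied to the difference of the two linear predictors. Since $\eta_{i2} - \eta_{i1} = x_i^\top(\beta_{\cdot 2} - \beta_{\cdot 1})$, a single coefficient vector suffices, which is the usual logistic regression parameterization. A quick Base-only check, with hypothetical helper names `softmax2` and `logistic`:

```julia
# Probability of category 2 under a two-category softmax
softmax2(η₁, η₂) = exp(η₂) / (exp(η₁) + exp(η₂))
# Logistic (sigmoid) function
logistic(t) = 1 / (1 + exp(-t))

for (η₁, η₂) in [(0.3, -1.2), (2.0, 0.5), (-0.7, -0.7)]
    # With k = 2, P(category 2) depends only on the difference η₂ - η₁
    @assert softmax2(η₁, η₂) ≈ logistic(η₂ - η₁)
end
```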
In this model, the parameters that we want to estimate are the coefficients $\beta$. We need to select prior distributions for these parameters. For each entry $\beta_{jc}$ of $\beta$, we choose an independent normal distribution with zero mean and unit variance. Here, $\beta_{jc}$ denotes the coefficient of the $j$th feature for the $c$th category.
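To make the generative story concrete, the following sketch draws $\beta$ from the prior as a $p \times k$ matrix of independent standard normals (matching `iid(p, k)` in the Soss model below) and then simulates labels. It uses Base plus the Random standard library; the sizes, the feature matrix, and the helper names `softmax_rows` and `draw_category` are hypothetical, and this is not how Soss samples internally:

```julia
using Random

# Row-wise softmax: turns each row of the n × k matrix η into a probability vector
function softmax_rows(η::AbstractMatrix{<:Real})
    z = exp.(η .- maximum(η; dims=2))   # subtract each row's maximum for numerical stability
    return z ./ sum(z; dims=2)
end

# Draw one category index from a probability vector by inverting the cumulative sum
function draw_category(rng::AbstractRNG, probs::AbstractVector{<:Real})
    u = rand(rng)
    idx = findfirst(>=(u), cumsum(probs))
    return something(idx, length(probs))   # guard against round-off when u is close to 1
end

rng = MersenneTwister(2021)
n, p, k = 6, 4, 3                 # hypothetical sizes: 6 observations, 4 features, 3 classes
X = randn(rng, n, p)              # hypothetical feature matrix
β = randn(rng, p, k)              # prior draw: every entry ~ Normal(0, 1), independently
η = X * β                         # linear predictor, n × k
μ = softmax_rows(η)               # μ = softmax(η), one probability vector per row
y = [draw_category(rng, μ[i, :]) for i in 1:n]   # Yᵢ ~ Categorical(μᵢ), as class indices
```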
We define this model using the Soss probabilistic programming library:
m = @model X,pool begin
n = size(X, 1) # number of observations
p = size(X, 2) # number of features
k = length(pool.levels) # number of classes
β ~ Normal(0.0, 1.0) |> iid(p, k) # p × k matrix of coefficients, each entry iid Normal(0, 1)
η = X * β # linear predictor, an n × k matrix
μ = NNlib.softmax(η; dims=2) # μ = g⁻¹(η) = softmax(η), normalized within each row (dims=2)
y_dists = UnivariateFinite(pool.levels, μ; pool=pool) # `UnivariateFinite` is mathematically equivalent to `Categorical`
y ~ For(j -> y_dists[j], n) # `Yᵢ ~ UnivariateFinite(mean=μᵢ, categories=k)`, which is mathematically equivalent to `Yᵢ ~ Categorical(mean=μᵢ, categories=k)`
end;
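For intuition about what the sampler targets, the log joint density $\log p(\beta) + \log p(y \mid X, \beta)$ of the model above, which equals the log posterior up to an additive constant, can be written out by hand. This is a Base-only sketch with a hypothetical `log_joint` helper and labels encoded as integers in `1:k`; Soss derives this density from the model automatically, so the code is only illustrative:

```julia
# Log joint density log p(β) + log p(y | X, β) for the multinomial logistic regression model,
# with labels y encoded as integers in 1:k.
function log_joint(β::AbstractMatrix{<:Real}, X::AbstractMatrix{<:Real}, y::AbstractVector{<:Integer})
    η = X * β                                               # n × k linear predictor
    m = maximum(η; dims=2)
    logμ = η .- (m .+ log.(sum(exp.(η .- m); dims=2)))      # row-wise log-softmax, computed stably
    loglik = sum(logμ[i, y[i]] for i in eachindex(y))       # Σᵢ log μ[i, yᵢ]: categorical likelihood
    logprior = -0.5 * sum(abs2, β) - 0.5 * length(β) * log(2π)   # iid Normal(0, 1) on every entry
    return logprior + loglik
end
```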
Import the Iris flower data set:
iris = dataset("datasets", "iris");
Define our feature columns:
feature_columns = [
:PetalLength,
:PetalWidth,
:SepalLength,
:SepalWidth,
]
4-element Vector{Symbol}:
 :PetalLength
 :PetalWidth
 :SepalLength
 :SepalWidth
Define our label column:
label_column = :Species
:Species
Convert the Soss model into a SossMLJModel:
model = SossMLJModel(;
model = m,
predictor = MLJBase.UnivariateFinite,
hyperparams = (pool=iris.Species.pool,),
infer = dynamicHMC,
response = :y,
);
Create an MLJ machine for fitting our model:
mach = MLJBase.machine(model, iris[!, feature_columns], iris[!, label_column])
Machine{SossMLJModel{,…}} @824 trained 0 times.
  args:
    1:  Source @285 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:  Source @862 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`
Fit the machine. This may take several minutes.
MLJBase.fit!(mach)
Machine{SossMLJModel{,…}} @824 trained 1 time.
  args:
    1:  Source @285 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:  Source @862 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`
Construct the joint posterior predictive distribution:
predictor_joint = MLJBase.predict_joint(mach, iris[!, feature_columns])
typeof(predictor_joint)
SossMLJ.SossMLJPredictor{SossMLJModel{MLJBase.UnivariateFinite, Soss.Model{NamedTuple{(:X, :pool), T} where T<:Tuple, TypeEncoding(begin k = length(pool.levels) p = size(X, 2) β ~ Normal(0.0, 1.0) |> iid(p, k) η = X * β μ = NNlib.softmax(η; dims = 2) y_dists = UnivariateFinite(pool.levels, μ; pool = pool) n = size(X, 1) y ~ For((j->begin y_dists[j] end), n) end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:pool,), Tuple{CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}, typeof(Soss.dynamicHMC), Symbol, typeof(SossMLJ.default_transform)}, Vector{NamedTuple{(:β,), Tuple{Matrix{Float64}}}}, Soss.Model{NamedTuple{(:X, :pool, :β), T} where T<:Tuple, TypeEncoding(begin η = X * β μ = NNlib.softmax(η; dims = 2) y_dists = UnivariateFinite(pool.levels, μ; pool = pool) n = size(X, 1) y ~ For((j->begin y_dists[j] end), n) end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:X, :pool), Tuple{Matrix{Float64}, CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}}
Draw a single sample from the joint posterior predictive distribution:
single_sample = rand(predictor_joint)
150-element Vector{CategoricalArrays.CategoricalValue{String, UInt8}}:
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 ⋮
 "versicolor"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
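Conceptually, one draw from the joint posterior predictive distribution uses a single posterior draw of $\beta$ for every row, so the sampled labels are correlated through $\beta$. The sketch below illustrates that idea in plain Julia under stated assumptions: the "posterior samples" are random stand-in matrices, and `joint_predictive_draw`, `softmax_rows`, and `draw_category` are hypothetical helpers, not SossMLJ's implementation:

```julia
using Random

# Row-wise softmax of an n × k matrix
function softmax_rows(η::AbstractMatrix{<:Real})
    z = exp.(η .- maximum(η; dims=2))
    return z ./ sum(z; dims=2)
end

# Draw one category index from a probability vector
function draw_category(rng::AbstractRNG, probs::AbstractVector{<:Real})
    idx = findfirst(>=(rand(rng)), cumsum(probs))
    return something(idx, length(probs))
end

# One joint draw: pick ONE posterior sample of β, then draw a label for every row with it
function joint_predictive_draw(rng::AbstractRNG, βsamples::Vector{<:AbstractMatrix}, X::AbstractMatrix)
    β = βsamples[rand(rng, 1:length(βsamples))]
    μ = softmax_rows(X * β)
    return [draw_category(rng, μ[i, :]) for i in axes(X, 1)]
end

rng = MersenneTwister(7)
X = randn(rng, 5, 4)                            # hypothetical 5 × 4 feature matrix
βsamples = [randn(rng, 4, 3) for _ in 1:100]    # hypothetical stand-in for posterior samples of β
y_draw = joint_predictive_draw(rng, βsamples, X)   # one class index (1, 2, or 3) per row
```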
For each row in the data set, construct the marginal posterior predictive distribution:
predictor_marginal = MLJBase.predict(mach, iris[!, feature_columns])
150-element MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.981, versicolor=>0.0188, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.955, versicolor=>0.0448, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.97, versicolor=>0.0298, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.952, versicolor=>0.0484, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.982, versicolor=>0.0178, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.981, versicolor=>0.019, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.974, versicolor=>0.0256, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.974, versicolor=>0.0264, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.946, versicolor=>0.0538, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.942, versicolor=>0.058, virginica=>0.0)
 ⋮
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.003, versicolor=>0.29, virginica=>0.707)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0002, versicolor=>0.081, virginica=>0.919)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.0534, virginica=>0.947)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.0478, virginica=>0.952)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.158, virginica=>0.841)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0002, versicolor=>0.181, virginica=>0.819)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0004, versicolor=>0.238, virginica=>0.761)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.0784, virginica=>0.921)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0016, versicolor=>0.176, virginica=>0.822)
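Each marginal above is, mathematically, $P(Y_i = c \mid \text{data}) = \mathbb{E}_{\beta \mid \text{data}}[\mu_{ic}]$, the posterior average of the row-$i$ softmax probabilities. One way to approximate it from $S$ posterior samples is a Monte Carlo average, sketched below. The posterior samples are random stand-ins and `marginal_predictive` is a hypothetical helper; SossMLJ may estimate these marginals differently (for example, by counting sampled labels):

```julia
using Random

# Row-wise softmax of an n × k matrix
function softmax_rows(η::AbstractMatrix{<:Real})
    z = exp.(η .- maximum(η; dims=2))
    return z ./ sum(z; dims=2)
end

# Monte Carlo estimate of the marginal posterior predictive probabilities:
# P(Yᵢ = c | data) ≈ (1/S) Σₛ softmax(X β⁽ˢ⁾)[i, c] over S posterior samples β⁽ˢ⁾
function marginal_predictive(βsamples::Vector{<:AbstractMatrix}, X::AbstractMatrix)
    return sum(softmax_rows(X * β) for β in βsamples) ./ length(βsamples)
end

rng = MersenneTwister(3)
X = randn(rng, 5, 4)                            # hypothetical feature matrix
βsamples = [randn(rng, 4, 3) for _ in 1:200]    # hypothetical stand-in for posterior samples of β
P = marginal_predictive(βsamples, X)            # 5 × 3: one probability vector per row
```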
predictor_marginal is a UnivariateFiniteVector:
typeof(predictor_marginal)
MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64} (alias for MLJBase.UnivariateFiniteArray{ScientificTypes.Multiclass{3}, String, UInt8, Float64, 1})
predictor_marginal has one element for each row in the data set:
@show size(predictor_marginal); @show size(iris, 1);
size(predictor_marginal) = (150,)
size(iris, 1) = 150
Use cross-validation to evaluate the model with respect to the Brier score:
evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=brier_score, operation=MLJBase.predict)
┌─────────────────┬───────────────┬─────────────────────────────────────┐
│ _.measure       │ _.measurement │ _.per_fold                          │
├─────────────────┼───────────────┼─────────────────────────────────────┤
│ BrierScore @322 │ -0.0901       │ [-0.0908, -0.107, -0.0835, -0.0787] │
└─────────────────┴───────────────┴─────────────────────────────────────┘
_.per_observation = [[[-0.012, -0.56, ..., -0.335], [-0.0383, -0.0935, ..., -0.00251], [-0.00239, -0.00749, ..., -0.637], [-0.0015, -0.117, ..., -0.000858]]]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]
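The reported values are all negative, which is consistent with MLJ's `BrierScore` being the negative of the multiclass Brier loss, so that higher is better and 0 is a perfect score. Assuming that convention, a per-observation value can be reproduced as below; `brier_loss` and `negated_brier` are hypothetical helpers and the probability vector is made up. The overall measurement of -0.0901 matches the average of the four per-fold values up to rounding.

```julia
# Multiclass Brier loss for one observation: Σ_c (p_c - 1{y = c})²,
# where `p` is the predicted probability vector and `y` the true class index
brier_loss(p::AbstractVector{<:Real}, y::Integer) =
    sum((p[c] - (c == y ? 1.0 : 0.0))^2 for c in eachindex(p))

# The same quantity negated and expanded: 2p_y - Σ_c p_c² - 1 (at most 0, higher is better)
negated_brier(p::AbstractVector{<:Real}, y::Integer) = 2 * p[y] - sum(abs2, p) - 1

p = [0.7, 0.2, 0.1]      # hypothetical predicted probabilities for three classes
@assert brier_loss(p, 1) ≈ 0.14
@assert negated_brier(p, 1) ≈ -brier_loss(p, 1)
```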
Use cross-validation to evaluate the model with respect to accuracy:
evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=accuracy, operation=MLJBase.predict_mode)
┌───────────────┬───────────────┬──────────────────────────────┐
│ _.measure     │ _.measurement │ _.per_fold                   │
├───────────────┼───────────────┼──────────────────────────────┤
│ Accuracy @394 │ 0.96          │ [0.974, 0.947, 0.946, 0.973] │
└───────────────┴───────────────┴──────────────────────────────┘
_.per_observation = [missing]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]
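The mechanics of 4-fold cross-validation with shuffling, together with accuracy computed from predicted modes, can be sketched as follows. This is an illustration using Base, Random, and Statistics, not MLJ's exact fold assignment; `kfold_indices`, `predicted_modes`, and `fold_accuracy` are hypothetical helpers:

```julia
using Random, Statistics

# Split the indices 1:n into `nfolds` folds whose sizes differ by at most one,
# optionally shuffling the indices first
function kfold_indices(rng::AbstractRNG, n::Integer, nfolds::Integer; shuffled::Bool=true)
    idx = shuffled ? randperm(rng, n) : collect(1:n)
    sizes = fill(div(n, nfolds), nfolds)
    sizes[1:rem(n, nfolds)] .+= 1          # spread the remainder over the first folds
    stops = cumsum(sizes)
    starts = stops .- sizes .+ 1
    return [idx[a:b] for (a, b) in zip(starts, stops)]
end

# Predicted class = mode (argmax) of each row's probability vector
predicted_modes(P::AbstractMatrix) = [argmax(P[i, :]) for i in axes(P, 1)]
# Accuracy = fraction of predictions equal to the true labels
fold_accuracy(yhat::AbstractVector, y::AbstractVector) = mean(yhat .== y)

folds = kfold_indices(MersenneTwister(0), 150, 4)   # 150 rows, as in the iris data
# Each fold serves once as the held-out test set while the model is fit on the other three;
# the overall score is then aggregated across the four per-fold scores.
```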
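In this run, the cross-validated accuracy is 0.96, well above 90%, which is pretty good! Because the folds are shuffled without a fixed random seed, the exact value may vary slightly from run to run.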
This page was generated using Literate.jl.