Multinomial logistic regression

Import the necessary packages:

using DataFrames
using Distributions
using MLJBase
using NNlib
using RDatasets
using Soss
using SossMLJ
using Statistics

In this example, we fit a Bayesian multinomial logistic regression model with the canonical link function.

Suppose that we are given a matrix of features X and a column vector of labels y. X has n rows and p columns. y has n elements. We assume that our observation vector y is a realization of a random variable Y. We define μ (mu) as the expected value of Y, i.e. μ := E[Y]. Our model comprises three components:

  1. The probability distribution of Y. We assume that each Yᵢ follows a multinomial distribution with k categories, mean μᵢ, and one trial.
  2. The systematic component, which consists of the linear predictor η (eta), defined as η := Xβ, where β is the p × k matrix of coefficients (one column per class).
  3. The link function g, which provides the following relationship: g(E[Y]) = g(μ) = η = Xβ. It follows that μ = g⁻¹(η), where g⁻¹ denotes the inverse of g. In multinomial logistic regression, the canonical link function is the generalized logit function. The inverse of the generalized logit function is the softmax function. Therefore, when using the canonical link function, μ = g⁻¹(η) = softmax(η). The short sketch after this list makes the shapes concrete.

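To make the shapes concrete, here is a standalone sketch with made-up dimensions (not tied to the iris data): β is a p × k matrix, so η = Xβ is n × k, and applying softmax along the second dimension turns each row of η into a probability vector over the k classes.

using NNlib

n, p, k = 5, 4, 3                 # made-up dimensions for illustration
X = randn(n, p)                   # n observations, p features
β = randn(p, k)                   # one column of coefficients per class
η = X * β                         # the n × k linear predictor
μ = NNlib.softmax(η; dims=2)      # inverse link: rows of μ are probability vectors
all(sum(μ; dims=2) .≈ 1.0)        # true: each row sums to one
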
A multinomial distribution with one trial is equivalent to the categorical distribution (see the short illustration after this list). Therefore, the following two statements are equivalent:

  • Yᵢ follows a multinomial distribution with k categories, mean μᵢ, and one trial.
  • Yᵢ follows a categorical distribution with k categories and mean μᵢ.

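A quick illustration with Distributions.jl (the probability vector is arbitrary): a draw from a one-trial multinomial is just a one-hot encoding of a categorical draw.

using Distributions

μᵢ = [0.2, 0.5, 0.3]              # arbitrary class probabilities for k = 3
rand(Multinomial(1, μᵢ))          # e.g. [0, 1, 0]: a one-hot count vector
rand(Categorical(μᵢ))             # e.g. 2: the corresponding class index
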
Observe that the logistic regression model is a special case of the multinomial logistic regression model where k = 2.

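To see this, pin the first class's linear predictor at zero; the softmax then collapses to the logistic sigmoid (a one-line check with an arbitrary η):

using NNlib

η = 0.7                                          # an arbitrary scalar linear predictor
NNlib.softmax([0.0, η])[2] ≈ NNlib.sigmoid(η)    # true
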
In this model, the parameters that we want to estimate are the coefficients β. We need to select prior distributions for these parameters. For each element βᵢⱼ of β (the coefficient of feature i for class j) we choose a normal distribution with zero mean and unit variance.

We define this model using the Soss probabilistic programming library:

m = @model X,pool begin
    n = size(X, 1) # number of observations
    p = size(X, 2) # number of features
    k = length(pool.levels) # number of classes
    β ~ Normal(0.0, 1.0) |> iid(p, k) # coefficients
    η = X * β # linear predictor
    μ = NNlib.softmax(η; dims=2) # μ = g⁻¹(η) = softmax(η)
    y_dists = UnivariateFinite(pool.levels, μ; pool=pool) # `UnivariateFinite` is mathematically equivalent to `Categorical`
    y ~ For(j -> y_dists[j], n) # `Yᵢ ~ UnivariateFinite(mean=μᵢ, categories=k)`, which is mathematically equivalent to `Yᵢ ~ Categorical(mean=μᵢ, categories=k)`
end;

Import the Iris flower data set:

iris = dataset("datasets", "iris");
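
Optionally, inspect the first few rows to get a feel for the columns:

first(iris, 5)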

Define our feature columns:

feature_columns = [
    :PetalLength,
    :PetalWidth,
    :SepalLength,
    :SepalWidth,
]
4-element Vector{Symbol}:
 :PetalLength
 :PetalWidth
 :SepalLength
 :SepalWidth

Define our label column:

label_column = :Species
:Species

Convert the Soss model into a SossMLJModel:

model = SossMLJModel(;
    model       = m,                         # the Soss model defined above
    predictor   = MLJBase.UnivariateFinite,  # how predicted distributions are represented
    hyperparams = (pool=iris.Species.pool,), # extra model arguments (the categorical pool of the labels)
    infer       = dynamicHMC,                # posterior sampling algorithm
    response    = :y,                        # name of the response variable in the model
);

Create an MLJ machine for fitting our model:

mach = MLJBase.machine(model, iris[!, feature_columns], iris[!, :Species])
Machine{SossMLJModel{,…}} @824 trained 0 times.
  args: 
    1:	Source @285 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:	Source @862 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`

Fit the machine. This runs the dynamicHMC sampler to draw posterior samples of the coefficients β, and may take several minutes.

MLJBase.fit!(mach)
Machine{SossMLJModel{,…}} @824 trained 1 time.
  args: 
    1:	Source @285 ⏎ `ScientificTypes.Table{AbstractVector{ScientificTypes.Continuous}}`
    2:	Source @862 ⏎ `AbstractVector{ScientificTypes.Multiclass{3}}`

Construct the joint posterior:

predictor_joint = MLJBase.predict_joint(mach, iris[!, feature_columns])
typeof(predictor_joint)
SossMLJ.SossMLJPredictor{SossMLJModel{MLJBase.UnivariateFinite, Soss.Model{NamedTuple{(:X, :pool), T} where T<:Tuple, TypeEncoding(begin
    k = length(pool.levels)
    p = size(X, 2)
    β ~ Normal(0.0, 1.0) |> iid(p, k)
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:pool,), Tuple{CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}, typeof(Soss.dynamicHMC), Symbol, typeof(SossMLJ.default_transform)}, Vector{NamedTuple{(:β,), Tuple{Matrix{Float64}}}}, Soss.Model{NamedTuple{(:X, :pool, :β), T} where T<:Tuple, TypeEncoding(begin
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end), TypeEncoding(Main.ex-example-multinomial-logistic-regression)}, NamedTuple{(:X, :pool), Tuple{Matrix{Float64}, CategoricalArrays.CategoricalPool{String, UInt8, CategoricalArrays.CategoricalValue{String, UInt8}}}}}

Draw a single sample from the joint posterior. Each draw simulates an entire response vector, one label per row of the data:

single_sample = rand(predictor_joint)
150-element Vector{CategoricalArrays.CategoricalValue{String, UInt8}}:
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 "setosa"
 ⋮
 "versicolor"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"
 "virginica"

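Since single_sample contains one simulated label per row of iris, a quick sanity check is to compare the draw against the observed labels (the result varies from draw to draw):

mean(single_sample .== iris[!, :Species])   # fraction of rows where the draw matches
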
For each row in the dataset, construct the marginal posterior predictive distribution:

predictor_marginal = MLJBase.predict(mach, iris[!, feature_columns])
150-element MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64}:
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.981, versicolor=>0.0188, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.955, versicolor=>0.0448, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.97, versicolor=>0.0298, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.952, versicolor=>0.0484, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.982, versicolor=>0.0178, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.981, versicolor=>0.019, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.974, versicolor=>0.0256, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.974, versicolor=>0.0264, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.946, versicolor=>0.0538, virginica=>0.0)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.942, versicolor=>0.058, virginica=>0.0)
 ⋮
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.003, versicolor=>0.29, virginica=>0.707)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0002, versicolor=>0.081, virginica=>0.919)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.0534, virginica=>0.947)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0, versicolor=>0.0478, virginica=>0.952)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.158, virginica=>0.841)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0002, versicolor=>0.181, virginica=>0.819)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0004, versicolor=>0.238, virginica=>0.761)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0008, versicolor=>0.0784, virginica=>0.921)
 UnivariateFinite{ScientificTypes.Multiclass{3}}(setosa=>0.0016, versicolor=>0.176, virginica=>0.822)

predictor_marginal is a UnivariateFiniteVector:

typeof(predictor_marginal)
MLJBase.UnivariateFiniteVector{ScientificTypes.Multiclass{3}, String, UInt8, Float64} (alias for MLJBase.UnivariateFiniteArray{ScientificTypes.Multiclass{3}, String, UInt8, Float64, 1})

predictor_marginal has one element for each row in the data set:

@show size(predictor_marginal); @show size(iris, 1);
size(predictor_marginal) = (150,)
size(iris, 1) = 150

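Class probabilities can be pulled out of a vector of UnivariateFinite distributions with a broadcasted pdf call. For example, the predicted probability of "virginica" for every row:

pdf.(predictor_marginal, "virginica")   # 150-element Vector{Float64}
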
Use cross-validation to evaluate the model with respect to the Brier score. Note that MLJ's brier_score is a negated, higher-is-better version of the classical Brier loss, so values lie in [-2, 0] and 0 is a perfect score:

evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=brier_score, operation=MLJBase.predict)
┌─────────────────┬───────────────┬─────────────────────────────────────┐
│ _.measure       │ _.measurement │ _.per_fold                          │
├─────────────────┼───────────────┼─────────────────────────────────────┤
│ BrierScore @322 │ -0.0901       │ [-0.0908, -0.107, -0.0835, -0.0787] │
└─────────────────┴───────────────┴─────────────────────────────────────┘
_.per_observation = [[[-0.012, -0.56, ..., -0.335], [-0.0383, -0.0935, ..., -0.00251], [-0.00239, -0.00749, ..., -0.637], [-0.0015, -0.117, ..., -0.000858]]]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]

Use cross-validation to evaluate the model with respect to accuracy. Here predict_mode turns each predictive distribution into a hard prediction by taking its most probable class:

evaluate!(mach, resampling=CV(; nfolds = 4, shuffle = true), measure=accuracy, operation=MLJBase.predict_mode)
┌───────────────┬───────────────┬──────────────────────────────┐
│ _.measure     │ _.measurement │ _.per_fold                   │
├───────────────┼───────────────┼──────────────────────────────┤
│ Accuracy @394 │ 0.96          │ [0.974, 0.947, 0.946, 0.973] │
└───────────────┴───────────────┴──────────────────────────────┘
_.per_observation = [missing]
_.fitted_params_per_fold = [ … ]
_.report_per_fold = [ … ]

The cross-validated accuracy is about 96%, which is pretty good!
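
For comparison, hard predictions on the training data can be obtained directly with predict_mode; in-sample accuracy computed this way is typically optimistic relative to the cross-validated figure above:

yhat = MLJBase.predict_mode(mach, iris[!, feature_columns])
mean(yhat .== iris[!, :Species])   # in-sample accuracy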


This page was generated using Literate.jl.