Source code for astroML.classification.gmm_bayes

"""
GMM Bayes
---------
This implements generative classification based on mixtures of Gaussians
to model the probability density of each class.
"""

import warnings
import numpy as np

# Compatibility shim: make the name ``_BaseNB`` available regardless of which
# name the installed scikit-learn uses for its naive Bayes base class.
try:
    from sklearn.naive_bayes import _BaseNB
except ImportError:
    # work around for sklearn < 0.22
    from sklearn.naive_bayes import BaseNB

    # Thin alias subclass so the rest of this module can always refer to
    # ``_BaseNB``; it adds no behavior of its own.
    class _BaseNB(BaseNB):
        pass

from sklearn.mixture import GaussianMixture
from sklearn.utils import check_array


class GMMBayes(_BaseNB):
    """GaussianMixture Bayes Classifier

    This is a generalization to the Naive Bayes classifier: rather than
    modeling the distribution of each class with axis-aligned gaussians,
    GMMBayes models the distribution of each class with mixtures of
    gaussians.  This can lead to better classification in some cases.

    Parameters
    ----------
    n_components : int or list
        number of components to use in the GaussianMixture. If specified as
        a list, it must match the number of class labels. Default is 1.
    **kwargs : dict, optional
        other keywords are passed directly to GaussianMixture

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique class labels found in ``y`` by ``fit``.
    gmms_ : list
        One fitted GaussianMixture per class, in the order of ``classes_``.
    class_prior_ : ndarray
        Fraction of the training samples that belong to each class.
    """

    def __init__(self, n_components=1, **kwargs):
        # Stored as a 1-d array so that a scalar and a per-class list are
        # handled uniformly (through broadcasting) in ``fit``.
        self.n_components = np.atleast_1d(n_components)
        self.kwargs = kwargs

    def fit(self, X, y):
        """Fit one GaussianMixture per class and estimate the class priors.

        Parameters
        ----------
        X : array-like
            Training data; validated through ``_check_X``.
        y : array-like
            Class label of each row of ``X``.

        Returns
        -------
        self : GMMBayes
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` disagree on the number of samples, or if
            ``n_components`` has neither one entry nor one entry per class.
        """
        X = self._check_X(X)
        y = np.asarray(y)

        n_samples = X.shape[0]
        if n_samples != y.shape[0]:
            raise ValueError("X and y have incompatible shapes")

        # np.unique already returns the labels in sorted order.
        self.classes_ = np.unique(y)
        n_classes = self.classes_.shape[0]

        if self.n_components.size not in (1, n_classes):
            raise ValueError("n_components must be compatible with "
                             "the number of classes")

        self.gmms_ = [None for _ in range(n_classes)]
        self.class_prior_ = np.zeros(n_classes)

        # Broadcast into a fresh per-class integer array; it may be clamped
        # below without touching ``self.n_components``.
        n_comp = np.zeros(n_classes, dtype=int) + self.n_components

        warnstr = ("Expected n_samples >= n_components but got "
                   "n_samples={0}, n_components={1}, "
                   "n_components set to {0}.")

        for i, y_i in enumerate(self.classes_):
            # Select this class's samples once and reuse the result.
            in_class = (y == y_i)
            X_i = X[in_class]
            n_class_samples = X_i.shape[0]

            if n_comp[i] > n_class_samples:
                # A mixture cannot have more components than samples.
                warnings.warn(warnstr.format(n_class_samples, n_comp[i]))
                n_comp[i] = n_class_samples

            self.gmms_[i] = GaussianMixture(n_comp[i],
                                            **self.kwargs).fit(X_i)
            self.class_prior_[i] = float(np.sum(in_class)) / n_samples

        return self

    def _joint_log_likelihood(self, X):
        """Return per-class mixture log-density plus log class prior.

        The result has one row per sample of ``X`` and one column per
        entry of ``gmms_``.
        """
        X = np.asarray(np.atleast_2d(X))
        # One row of scores per class; transpose to (samples, classes).
        logprobs = np.array([g.score_samples(X) for g in self.gmms_]).T
        return logprobs + np.log(self.class_prior_)

    def _check_X(self, X):
        """Validate ``X`` with ``sklearn.utils.check_array``."""
        return check_array(X)