Extreme Deconvolution of Stellar Data

Figure 6.12

Extreme deconvolution applied to stellar data from SDSS Stripe 82. The top panels compare the color distributions for a high signal-to-noise sample of standard stars (left) with lower signal-to-noise, single-epoch data (right). The middle panels show the results of applying extreme deconvolution to the single-epoch data. The bottom panel compares the distributions of a color measured perpendicularly to the locus (the so-called w color, defined following Ivezić et al. 2004). The distribution of colors from the extreme deconvolution of the noisy data recovers the tight distribution of the high signal-to-noise data.

../../_images/fig_stellar_XD_2.png ../../_images/fig_stellar_XD_1.png
number of noisy points:   (82003, 2)
number of stacked points: (13377, 2)
size after crossmatch: (12313, 5)
@pickle_results: computing results and saving to 'XD_stellar.pkl'
1: log(L) = 32880
    (17 sec)
2: log(L) = 33425
    (17 sec)
3: log(L) = 33756
    (17 sec)
4: log(L) = 34005
    (17 sec)
5: log(L) = 34186
    (17 sec)
6: log(L) = 34320
    (17 sec)
7: log(L) = 34427
    (17 sec)
8: log(L) = 34515
    (17 sec)
9: log(L) = 34590
    (17 sec)
10: log(L) = 34654
    (17 sec)
11: log(L) = 34711
    (17 sec)
12: log(L) = 34763
    (17 sec)
13: log(L) = 34810
    (19 sec)
14: log(L) = 34854
    (18 sec)
15: log(L) = 34893
    (18 sec)
16: log(L) = 34931
    (20 sec)
17: log(L) = 34969
    (18 sec)
18: log(L) = 35013
    (18 sec)
19: log(L) = 35061
    (19 sec)
20: log(L) = 35098
    (18 sec)
21: log(L) = 35131
    (20 sec)
22: log(L) = 35164
    (18 sec)
23: log(L) = 35195
    (20 sec)
24: log(L) = 35225
    (21 sec)
25: log(L) = 35255
    (19 sec)
26: log(L) = 35285
    (19 sec)
27: log(L) = 35315
    (19 sec)
28: log(L) = 35345
    (19 sec)
29: log(L) = 35374
    (21 sec)
30: log(L) = 35402
    (19 sec)
31: log(L) = 35429
    (18 sec)
32: log(L) = 35455
    (17 sec)
33: log(L) = 35479
    (17 sec)
34: log(L) = 35503
    (17 sec)
35: log(L) = 35525
    (17 sec)
36: log(L) = 35547
    (17 sec)
37: log(L) = 35567
    (17 sec)
38: log(L) = 35586
    (17 sec)
39: log(L) = 35605
    (17 sec)
40: log(L) = 35622
    (20 sec)
41: log(L) = 35639
    (20 sec)
42: log(L) = 35655
    (17 sec)
43: log(L) = 35671
    (17 sec)
44: log(L) = 35686
    (17 sec)
45: log(L) = 35700
    (17 sec)
46: log(L) = 35713
    (17 sec)
47: log(L) = 35726
    (18 sec)
48: log(L) = 35738
    (17 sec)
49: log(L) = 35750
    (17 sec)
50: log(L) = 35761
    (19 sec)
51: log(L) = 35771
    (20 sec)
52: log(L) = 35781
    (20 sec)
53: log(L) = 35790
    (19 sec)
54: log(L) = 35798
    (19 sec)
55: log(L) = 35806
    (20 sec)
56: log(L) = 35814
    (20 sec)
57: log(L) = 35821
    (19 sec)
58: log(L) = 35827
    (18 sec)
59: log(L) = 35834
    (19 sec)
60: log(L) = 35839
    (20 sec)
61: log(L) = 35845
    (19 sec)
62: log(L) = 35850
    (18 sec)
63: log(L) = 35855
    (18 sec)
64: log(L) = 35859
    (19 sec)
65: log(L) = 35864
    (18 sec)
66: log(L) = 35868
    (19 sec)
67: log(L) = 35872
    (19 sec)
68: log(L) = 35875
    (20 sec)
69: log(L) = 35879
    (19 sec)
70: log(L) = 35882
    (20 sec)
71: log(L) = 35885
    (19 sec)
72: log(L) = 35888
    (18 sec)
73: log(L) = 35891
    (18 sec)
74: log(L) = 35894
    (18 sec)
75: log(L) = 35896
    (19 sec)
76: log(L) = 35899
    (18 sec)
77: log(L) = 35901
    (18 sec)
78: log(L) = 35904
    (17 sec)
79: log(L) = 35906
    (18 sec)
80: log(L) = 35908
    (18 sec)
81: log(L) = 35910
    (18 sec)
82: log(L) = 35912
    (18 sec)
83: log(L) = 35914
    (19 sec)
84: log(L) = 35916
    (20 sec)
85: log(L) = 35917
    (19 sec)
86: log(L) = 35919
    (18 sec)
87: log(L) = 35921
    (19 sec)
88: log(L) = 35922
    (18 sec)
89: log(L) = 35924
    (18 sec)
90: log(L) = 35925
    (17 sec)
91: log(L) = 35927
    (17 sec)
92: log(L) = 35928
    (18 sec)
93: log(L) = 35930
    (18 sec)
94: log(L) = 35931
    (17 sec)
95: log(L) = 35932
    (18 sec)
96: log(L) = 35934
    (18 sec)
97: log(L) = 35935
    (18 sec)
98: log(L) = 35936
    (18 sec)
99: log(L) = 35937
    (18 sec)
100: log(L) = 35938
    (18 sec)
# Author: Jake VanderPlas
# License: BSD
#   The figure produced by this code is published in the textbook
#   "Statistics, Data Mining, and Machine Learning in Astronomy" (2013)
#   For more information, see http://astroML.github.com
#   To report a bug or issue, use the following forum:
#    https://groups.google.com/forum/#!forum/astroml-general

from __future__ import print_function, division

import numpy as np
from matplotlib import pyplot as plt

from astroML.density_estimation import XDGMM
from astroML.crossmatch import crossmatch
from astroML.datasets import fetch_sdss_S82standards, fetch_imaging_sample
from astroML.plotting.tools import draw_ellipse
from astroML.utils.decorators import pickle_results
from astroML.stats import sigmaG

#----------------------------------------------------------------------
# Apply the textbook's uniform matplotlib styling (8 pt fonts, LaTeX text).
# With usetex=True all text is typeset by LaTeX, which raises an error when
# LaTeX is not installed on your system -- switch to usetex=False then.
# The helper is imported only when the name is not already defined in this
# module's namespace (e.g. injected by the environment running the script).
try:
    setup_text_plots
except NameError:
    from astroML.plotting import setup_text_plots
setup_text_plots(fontsize=8, usetex=True)

#------------------------------------------------------------
# Per-band extinction coefficients for u, g, r, i, z relative to the r band
# (Berry et al., arXiv:1111.4985).  Multiplying by an object's A_r gives the
# extinction to subtract in each of the five bands.
extinction_vector = np.array([1.810, 1.400, 1.0, 0.759, 0.561])

#----------------------------------------------------------------------
# Noisy, single-epoch imaging data
data_noisy = fetch_imaging_sample()

# keep only the objects flagged as stars (type code 6)
is_star = data_noisy['type'] == 6
data_noisy = data_noisy[is_star]

# (N, 5) PSF magnitudes and their errors, one column per band (u, g, r, i, z)
X = np.vstack([data_noisy[band + 'RawPSF'] for band in 'ugriz']).T
Xerr = np.vstack([data_noisy[band + 'psfErr'] for band in 'ugriz']).T

# De-redden: subtract coefficient * A_r in every band.  `rExtSFD` is used as
# the per-object A_r (presumably the SFD-map r-band extinction).
X -= extinction_vector * data_noisy['rExtSFD'][:, np.newaxis]


#----------------------------------------------------------------------
# High signal-to-noise stacked photometry of the Stripe 82 standard stars
data_stacked = fetch_sdss_S82standards()

# restrict to the sky region of the single-epoch imaging sample:
# 0 < RA < 10 and -1 < DEC < 1 (presumably degrees)
RA = data_stacked['RA']
DEC = data_stacked['DEC']
in_field = (RA > 0) & (RA < 10) & (DEC > -1) & (DEC < 1)
data_stacked = data_stacked[in_field]

# (N, 5) stacked magnitudes and their errors, columns ordered u, g, r, i, z
Y = np.vstack([data_stacked['mmu_' + band] for band in 'ugriz']).T
Yerr = np.vstack([data_stacked['msig_' + band] for band in 'ugriz']).T

# same Berry et al. (arXiv:1111.4985) de-reddening as the noisy sample,
# here scaled by the catalog's own A_r column
Y -= extinction_vector * data_stacked['A_r'][:, np.newaxis]

# quality cuts: largest per-band error below 0.05, and g (column 1) < 20
g = Y[:, 1]
mask = (Yerr.max(axis=1) < 0.05) & (g < 20)
Y, Yerr = Y[mask], Yerr[mask]
data_stacked = data_stacked[mask]


#----------------------------------------------------------------------
# cross-match
#  the imaging sample contains both standard and variable stars.  We'll
#  perform a cross-match with the standard star catalog and choose objects
#  which are common to both.
# (N, 2) arrays of (ra, dec) positions, one row per object
Xlocs = np.hstack((data_noisy['ra'][:, np.newaxis],
                   data_noisy['dec'][:, np.newaxis]))
Ylocs = np.hstack((data_stacked['RA'][:, np.newaxis],
                   data_stacked['DEC'][:, np.newaxis]))

print("number of noisy points:  ", Xlocs.shape)
print("number of stacked points:", Ylocs.shape)

# find all points within 0.9 arcsec.  This cutoff was selected
# by plotting a histogram of the log(distances).
# NOTE(review): positions are passed as plain (ra, dec) pairs, so the
# separation is presumably computed in flat coordinates.  That is adequate
# for a field near dec = 0 (cos(dec) ~ 1) but should be confirmed before
# reusing this match at higher declination.
dist, ind = crossmatch(Xlocs, Ylocs, max_distance=0.9 / 3600)

# An infinite distance is treated as "no counterpart within max_distance".
noisy_mask = ~np.isinf(dist)
# For each matched noisy object, the row of its counterpart in the stacked set.
stacked_mask = ind[noisy_mask]

# select the data
data_noisy = data_noisy[noisy_mask]
X = X[noisy_mask]
Xerr = Xerr[noisy_mask]

data_stacked = data_stacked[stacked_mask]
Y = Y[stacked_mask]
Yerr = Yerr[stacked_mask]

# Double-check that the cross-match produced row-aligned samples.  This is an
# explicit check rather than ``assert`` so the guard is not stripped when
# Python runs with -O.
if X.shape != Y.shape:
    raise RuntimeError("cross-match produced mismatched samples: "
                       "X.shape=%s, Y.shape=%s" % (X.shape, Y.shape))
print("size after crossmatch:", X.shape)


#----------------------------------------------------------------------
# perform extreme deconvolution on the noisy sample

# Mixing matrix W: each row builds one fit feature from the (u, g, r, i, z)
# magnitudes.
W = np.array([[0, 1, 0, 0, 0],    # g magnitude
              [1, -1, 0, 0, 0],   # u-g color
              [0, 1, -1, 0, 0],   # g-r color
              [0, 0, 1, -1, 0],   # r-i color
              [0, 0, 0, 1, -1]])  # i-z color

# transform every row of both samples: x' = W x
X = X.dot(W.T)
Y = Y.dot(W.T)

# Per-object error covariance in magnitude space: diagonal, holding the
# squared per-band errors (the bands' errors are treated as independent).
n_bands = Xerr.shape[1]
band_idx = np.arange(n_bands)
Xcov = np.zeros((Xerr.shape[0], n_bands, n_bands))
Xcov[:, band_idx, band_idx] = Xerr ** 2

# Propagate each covariance through the linear map: C' = W C W^T.
Xcov = np.einsum('ij,njk,lk->nil', W, Xcov, W)


#----------------------------------------------------------------------
# This is a long calculation: save results to file
@pickle_results("XD_stellar.pkl")
def compute_XD(n_clusters=12, rseed=0, max_iter=100, verbose=True):
    """Fit an extreme-deconvolution Gaussian mixture to the noisy sample.

    Uses the module-level feature array ``X`` and its per-object error
    covariances ``Xcov``.  The ``pickle_results`` decorator stores the
    result in 'XD_stellar.pkl' so the slow fit can be reloaded.

    Parameters
    ----------
    n_clusters : int
        Number of mixture components passed to ``XDGMM``.
    rseed : int
        Seed applied to NumPy's global random generator before fitting.
    max_iter : int
        Maximum number of fitting iterations.
    verbose : bool
        Forwarded to ``XDGMM``.

    Returns
    -------
    XDGMM
        The fitted model.
    """
    np.random.seed(rseed)
    model = XDGMM(n_clusters, max_iter=max_iter, tol=1E-5, verbose=verbose)
    model.fit(X, Xcov)
    return model


# fit (or reload from the pickle cache) a 12-component XD model
clf = compute_XD(12)

#------------------------------------------------------------
# Draw one point from the deconvolved model per noisy observation.  The seed
# is set first; clf.sample presumably draws from NumPy's global generator.
np.random.seed(42)
X_sample = clf.sample(len(X))

#------------------------------------------------------------
# plot the results: a 2x2 grid in the (g-r, r-i) color-color plane
fig = plt.figure(figsize=(5, 3.75))
fig.subplots_adjust(left=0.12, right=0.95, bottom=0.1, top=0.95,
                    wspace=0.02, hspace=0.02)

# Panels 221-223: every 10th star (thinned for clarity) from the standard
# stars, the single-epoch data and the XD resampled points.  Columns 2 and 3
# of each feature array hold g-r and r-i.
scatter_panels = []
for position, points in zip((221, 222, 223), (Y, X, X_sample)):
    panel = fig.add_subplot(position)
    panel.scatter(points[::10, 2], points[::10, 3], s=9, lw=0, c='k')
    scatter_panels.append(panel)
ax1, ax2, ax3 = scatter_panels

# Panel 224: one ellipse per XD component from the (g-r, r-i) slice of its
# mean and covariance; scales=[2] presumably draws the 2-sigma contour.
ax4 = fig.add_subplot(224)
for k in range(clf.n_components):
    draw_ellipse(clf.mu[k, 2:4], clf.V[k, 2:4, 2:4], scales=[2],
                 ec='k', fc='gray', alpha=0.2, ax=ax4)

titles = ["Standard Stars", "Single Epoch",
          "Extreme Deconvolution\n  resampling",
          "Extreme Deconvolution\n  cluster locations"]
axes = [ax1, ax2, ax3, ax4]

# Shared limits and ticks; tick labels only on the outer edges of the grid.
for idx, (axis, title) in enumerate(zip(axes, titles)):
    axis.set_xlim(-0.6, 1.8)
    axis.set_ylim(-0.6, 1.8)

    axis.xaxis.set_major_locator(plt.MultipleLocator(0.5))
    axis.yaxis.set_major_locator(plt.MultipleLocator(0.5))

    axis.text(0.05, 0.95, title,
              ha='left', va='top', transform=axis.transAxes)

    in_top_row = idx < 2
    in_right_column = idx % 2 == 1

    if in_top_row:
        axis.xaxis.set_major_formatter(plt.NullFormatter())
    else:
        axis.set_xlabel('$g-r$')

    if in_right_column:
        axis.yaxis.set_major_formatter(plt.NullFormatter())
    else:
        axis.set_ylabel('$r-i$')

#------------------------------------------------------------
# Second figure: the width of the stellar locus, measured with the w color
# (Ivezic et al. 2004), a color taken perpendicular to the locus.
fig = plt.figure(figsize=(5, 3.75))
ax = fig.add_subplot(111)

w_bins = np.linspace(-0.08, 0.08, 100)
samples = [(X, 'single epoch', 'solid'),
           (Y, 'standard stars', 'dashed'),
           (X_sample, 'XD resampled', 'dotted')]

for data, label, ls in samples:
    # Feature columns are [g, u-g, g-r, r-i, i-z].  Rebuild the g, r, i
    # magnitudes, keeping only stars with 0.3 < g-r < 1.0.
    gr = data[:, 2]
    keep = (gr > 0.3) & (gr < 1.0)
    g_mag = data[keep, 0]
    r_mag = g_mag - gr[keep]
    i_mag = r_mag - data[keep, 3]

    w = -0.227 * g_mag + 0.792 * r_mag - 0.567 * i_mag + 0.05
    # robust width estimate shown in the legend
    legend_text = label + '\n\t' + r'$\sigma_G=%.3f$' % sigmaG(w)

    ax.hist(w, bins=w_bins, linestyle=ls, histtype='step',
            label=legend_text, density=True)

ax.legend(loc=2)
ax.text(0.95, 0.95, '$w = -0.227g + 0.792r$\n$ - 0.567i + 0.05$',
        transform=ax.transAxes, ha='right', va='top')

ax.set_xlim(-0.07, 0.07)
ax.set_ylim(0, 55)
ax.set_xlabel('$w$')
ax.set_ylabel('$N(w)$')

plt.show()