# Clustering of LINEAR data

## Figure 10.20

Unsupervised clustering analysis of periodic variable stars from the LINEAR data set. The top row shows clusters derived using two attributes (g - i and log P) and a mixture of 12 Gaussians. The colorized symbols mark the five most significant clusters. The bottom row shows analogous diagrams for clustering based on seven attributes (colors u - g, g - i, i - K, and J - K; log P, light-curve amplitude, and light-curve skewness), and a mixture of 15 Gaussians. See figure 10.21 for data projections in the space of other attributes for the latter case.

## Figure 10.21

Unsupervised clustering analysis of periodic variable stars from the LINEAR data set. Clusters are derived using seven attributes (colors u - g, g - i, i - K, and J - K; log P, light-curve amplitude, and light-curve skewness), and a mixture of 15 Gaussians. The log P vs. g - i diagram and log P vs. light-curve amplitude diagram for the same clusters are shown in the lower panels of figure 10.20.

________________________________________________________________________________
fig_LINEAR_clustering.py is not compiling:
________________________________________________________________________________

# Author: Jake VanderPlas
#   The figure produced by this code is published in the textbook
#   "Statistics, Data Mining, and Machine Learning in Astronomy" (2013)
#   To report a bug or issue, use the following forum:
#    https://groups.google.com/forum/#!forum/astroml-general
from __future__ import print_function, division

import numpy as np
from matplotlib import pyplot as plt

from sklearn.mixture import GaussianMixture

from astroML.utils.decorators import pickle_results
from astroML.datasets import fetch_LINEAR_geneva

#----------------------------------------------------------------------
# This function adjusts matplotlib settings for a uniform feel in the textbook.
# Note that with usetex=True, fonts are rendered with LaTeX.  This may
# result in an error if LaTeX is not installed on your system.  In that case,
# you can set usetex to False.
if "setup_text_plots" not in globals():
    # Only import when the name is not already defined at module level,
    # so a version supplied by the executing environment is not overridden.
    from astroML.plotting import setup_text_plots
setup_text_plots(fontsize=8, usetex=True)

#------------------------------------------------------------
# Get the Geneva periods data
data = fetch_LINEAR_geneva()

#----------------------------------------------------------------------
# compute Gaussian Mixture models

# NOTE: not referenced anywhere else in this script; kept so that any
# external code relying on the name keeps working.
filetemplate = 'gmm_res_%i_%i.pkl'

# Two attribute sets are clustered independently: a 2-attribute set
# (g-i color and log period) and a 7-attribute set (four colors,
# log period, amplitude and skewness).
attributes = [('gi', 'logP'),
              ('ug', 'gi', 'iK', 'JK', 'logP', 'amp', 'skew')]

# Numbers of mixture components to try: 1 through 20
components = np.arange(1, 21)

#------------------------------------------------------------
# Create attribute arrays
# One matrix per attribute set.  Stacking the columns and transposing puts
# one attribute per column (assumes each data[a] is a 1-D column of the
# record array -- TODO confirm against fetch_LINEAR_geneva).
Xarrays = [np.vstack([data[a] for a in attr]).T for attr in attributes]

#------------------------------------------------------------
# Compute the results (and save to pickle file)
@pickle_results('LINEAR_clustering.pkl')
def compute_GaussianMixture_results(components, attributes):
    """Fit one Gaussian mixture model per component count and attribute set.

    Parameters
    ----------
    components : iterable of int
        Numbers of mixture components to try; one model is fitted for each.
    attributes : sequence
        Attribute-name tuples.  They are only zipped against the
        module-level ``Xarrays`` (which holds the feature matrices), so
        they determine how many attribute sets are processed; the names
        themselves are not otherwise used in this function.

    Returns
    -------
    clfs : list of list of GaussianMixture
        ``clfs[i][j]`` is the model fitted on ``Xarrays[i]`` using
        ``components[j]`` components.

    Notes
    -----
    NOTE(review): ``pickle_results`` presumably caches the return value in
    'LINEAR_clustering.pkl' keyed on the call arguments.  Because the data
    come from the module-level ``Xarrays`` rather than from the arguments,
    a cached result may not notice a change in the underlying data --
    confirm against the astroML documentation and delete the pickle file
    when in doubt.
    """
    clfs = []

    for _attr, X in zip(attributes, Xarrays):
        clfs_i = []

        for comp in components:
            print("  - {0} component fit".format(comp))
            # Fixed random_state keeps the fits reproducible between runs
            clf = GaussianMixture(comp, covariance_type='full',
                                  random_state=0, max_iter=500)
            clf.fit(X)
            clfs_i.append(clf)

            # The model is kept either way; just warn on the console
            if not clf.converged_:
                print("           NOT CONVERGED!")

        clfs.append(clfs_i)

    return clfs

clfs = compute_GaussianMixture_results(components, attributes)

#------------------------------------------------------------
# Plot the results
fig = plt.figure(figsize=(5, 5))

# Cluster label of every object, one array per attribute set
# (index 0: 2-attribute clustering, index 1: 7-attribute clustering)
class_labels = []

for i in range(2):
    # Grab the best classifier, based on the BIC
    X = Xarrays[i]
    BIC = [model.bic(X) for model in clfs[i]]
    i_best = np.argmin(BIC)

    print("number of components:", components[i_best])

    clf = clfs[i][i_best]
    n_components = clf.n_components

    # Predict the cluster labels with the classifier
    c = clf.predict(X)

    class_labels.append(c)

    # sort the cluster by normalized density of points
    # bincount with minlength yields one count per mixture component, even
    # for components that received no points, so the counts stay aligned
    # with clf.covariances_ and the sorted indices are valid labels.
    n_members = np.bincount(c, minlength=n_components)
    size = np.array([np.linalg.det(C) for C in clf.covariances_])
    density = n_members / size

    # Clusters with very few points are less interesting:
    # set their density to zero so they'll go to the end of the list
    density[n_members < 5] = 0
    isort = np.argsort(density)[::-1]

    # find statistics of the top Nclusters clusters
    Nclusters = 6

    means = []
    stdevs = []
    counts = []

    names = [name for name in data.dtype.names[2:]
             if name != 'LINEARobjectID']
    labels = ['$u-g$', '$g-i$', '$i-K$', '$J-K$',
              r'$\log(P)$', 'amplitude', 'skew',
              'kurtosis', 'median mag', r'$N_{\rm obs}$', 'Visual Class']

    # Sanity check: every summarized column needs a display label
    assert len(names) == len(labels)

    i_logP = names.index('logP')

    for j in range(Nclusters):
        flag = (c == isort[j])
        counts.append(np.sum(flag))
        means.append([np.mean(data[n][flag]) for n in names])
        stdevs.append([data[n][flag].std() for n in names])

    counts = np.array(counts)
    means = np.array(means)
    stdevs = np.array(stdevs)

    # define colors based on mean of logP (longest period first)
    j_ordered = np.argsort(-means[:, i_logP])

    # tweak colors by hand
    if i == 1:
        j_ordered[3], j_ordered[2] = j_ordered[2], j_ordered[3]

    # color 0 means "not in one of the top clusters"; 1..Nclusters index
    # the selected clusters in the order chosen above
    color = np.zeros(c.shape)
    for j in range(Nclusters):
        flag = (c == isort[j_ordered[j]])
        color[flag] = j + 1

    # separate into foreground and background
    back = (color == 0)
    fore = ~back

    # Plot the resulting clusters
    # Row i of the 2x2 grid: left panel log(P) vs g-i, right panel
    # log(P) vs amplitude
    ax1 = fig.add_subplot(221 + 2 * i)
    ax1.scatter(data['gi'][back], data['logP'][back],
                c='gray', edgecolors='none', s=4, linewidths=0)
    ax1.scatter(data['gi'][fore], data['logP'][fore],
                c=color[fore], edgecolors='none', s=4, linewidths=0)

    ax1.set_ylabel(r'$\log(P)$')

    ax2 = fig.add_subplot(222 + 2 * i)
    ax2.scatter(data['amp'][back], data['logP'][back],
                c='gray', edgecolors='none', s=4, linewidths=0)
    ax2.scatter(data['amp'][fore], data['logP'][fore],
                c=color[fore], edgecolors='none', s=4, linewidths=0)

    #------------------------------
    # set axis limits
    ax1.set_xlim(-0.6, 2.1)
    ax2.set_xlim(0.1, 1.5)
    ax1.set_ylim(-1.5, 0.5)
    ax2.set_ylim(-1.5, 0.5)

    # The right-hand panel shares the y range with the left one, so its
    # tick labels are suppressed; the top row likewise drops x tick labels
    ax2.yaxis.set_major_formatter(plt.NullFormatter())
    if i == 0:
        ax1.xaxis.set_major_formatter(plt.NullFormatter())
        ax2.xaxis.set_major_formatter(plt.NullFormatter())
    else:
        ax1.set_xlabel(r'$g-i$')
        ax2.set_xlabel(r'$A$')

    #------------------------------
    # print table of means and standard deviations directly to LaTeX format
    # (only the first seven columns: the four colors, log P, amplitude, skew)
    print(r"\begin{tabular}{|l|lllllll|}")
    print(r"   \hline")
    for j in range(7):
        print('   &', labels[j], end=" ")
    print(r"\\")
    print(r"   \hline")

    for j in range(Nclusters):
        print("   {0} ".format(j + 1), end=' ')
        for k in range(7):
            # raw string: "\p" is not a valid escape in a normal literal
            print(r" & $%.2f \pm %.2f$ " % (means[j, k], stdevs[j, k]),
                  end=' ')
        print(r"\\")

    print(r"\hline")
    print(r"\end{tabular}")

#------------------------------------------------------------
# Second figure
# Shows the remaining attributes against g-i.  It reuses `back`, `fore` and
# `color` as left by the final pass of the first-figure loop, i.e. the
# clustering on the seven-attribute set.
fig = plt.figure(figsize=(5, 5))

attrs = ['skew', 'ug', 'iK', 'JK']
labels = ['skew', '$u-g$', '$i-K$', '$J-K$']
ylims = [(-1.8, 2.2), (0.6, 2.9), (0.1, 2.6), (-0.2, 1.2)]

for i in range(4):
    # One panel per attribute in a 2x2 grid
    ax = fig.add_subplot(221 + i)
    ax.scatter(data['gi'][back], data[attrs[i]][back],
               c='gray', edgecolors='none', s=4, linewidths=0)
    ax.scatter(data['gi'][fore], data[attrs[i]][fore],
               c=color[fore], edgecolors='none', s=4, linewidths=0)
    ax.set_xlabel('$g-i$')
    ax.set_ylabel(labels[i])

    ax.set_xlim(-0.6, 2.1)
    ax.set_ylim(ylims[i])

#------------------------------------------------------------
# Save the results
#
# run the script as
#
#   $ python fig_LINEAR_clustering.py --save
#
# to output the data file showing the cluster labels of each point
import sys
if len(sys.argv) > 1 and sys.argv[1] == '--save':
    filename = 'cluster_labels.dat'

    print("Saving cluster labels to", filename)

    from astroML.datasets.LINEAR_sample import ARCHIVE_DTYPE

    # Extend the archive record layout with the two cluster-label columns
    new_data = np.zeros(len(data),
                        dtype=(ARCHIVE_DTYPE + [('2D_cluster_ID', 'i4'),
                                                ('7D_cluster_ID', 'i4')]))

    for name in data.dtype.names:
        new_data[name] = data[name]
    new_data['2D_cluster_ID'] = class_labels[0]
    new_data['7D_cluster_ID'] = class_labels[1]

    # One conversion per output column, in the order of new_data.dtype.names
    fmt = ('%.6f   %.6f   %.3f   %.3f   %.3f   %.3f   %.7f   %.3f   %.3f   '
           '%.3f    %.2f     %i     %i      %s          %i              %i\n')

    # Context manager guarantees the file is closed even if a row fails
    # to format
    with open(filename, 'w') as F:
        F.write('#    ra           dec       ug      gi      iK      JK     '
                'logP       Ampl    skew      kurt    magMed    nObs  LCtype  '
                'LINEARobjectID  2D_cluster_ID   7D_cluster_ID\n')
        for line in new_data:
            F.write(fmt % tuple(line[col] for col in line.dtype.names))

plt.show()