Analyzing Microarray Data: A tutorial in unsupervised machine learning

A DNA Microarray

This is a tutorial that demonstrates some high-level concepts in bioinformatics, unsupervised machine learning, and data visualization. In particular, it performs a K-means clustering analysis on microarray data from 6153 genes at 7 timepoints, and demonstrates several methods for visualizing the resulting data.

All of the code in this notebook is written in Python 3, and a copy of the notebook is available at My GitHub.

This page is a work in progress, so please forgive any errors or unclear passages. You can send suggestions for improving it to the email address given on the aforementioned GitHub.

**To run this notebook on your own computer:**
  1. Install a Python distribution such as [Anaconda](https://store.continuum.io/cshop/anaconda/), or separately download all of the dependencies listed in the readme file in [My GitHub](https://github.com/GSimkus/bioinformatics_tutorial_clustering/).
  2. Download the source for this notebook to your computer from [here](https://github.com/GSimkus/bioinformatics_tutorial_clustering/blob/master/yeast_clustering_tutorial.ipynb)
    (Click the download button, then save the page as "yeast_clustering_tutorial.ipynb" and remove any ".txt" that appears)
  3. (Optional) Download and install FFMPEG by following [this guide](https://github.com/adaptlearning/adapt_authoring/wiki/Installing-FFmpeg). You'll need this to generate the animations at the end of the notebook
  4. Launch "Jupyter Notebook", then select "yeast_clustering_tutorial.ipynb" in the browser

The first thing to do in any script is to "import" the libraries we'll use:

While there are a number of extremely high-quality libraries out there that can perform most of the tasks we cover below, I wanted to provide some examples of simple, easy-to-follow code that you can write yourself. Instead of using a library like scikit-learn and running a command like "sklearn.cluster.KMeans()", which handles everything for you, I want to show you how to design a machine learning algorithm from scratch, and how to use it to learn things about real-life data.
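
For contrast, here's roughly what the library version looks like (a sketch, assuming scikit-learn is installed and that X is an array of data points):

from sklearn.cluster import KMeans

# One call does everything we'll build by hand below:
# pick starting centers, assign points, and iterate until convergence
labels = KMeans(n_clusters=3).fit_predict(X)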

There are, however, a few things we will get help with - like rendering our graphs, working with math, and structuring the data itself - so that we can focus on the bioinformatics.

In this case, the libraries we'll need are:

  • math, Python's built-in mathematics library. We will need its square root and ceiling functions.
  • numpy describes itself as "the fundamental package for scientific computing with Python". It's not wrong.
  • pandas is a library that makes working with structured data, like tables, much easier.
  • matplotlib provides all sorts of tools for visualizing data, including animation, the module we import here.

Also being done here:

  • %matplotlib inline is a "magic function" in Jupyter that ensures that any plots we make in this notebook are rendered in the browser, in HTML.
  • rc('animation', html='html5') sets the rendering settings on the animations and videos to be compatible with a web browser
  • np.random.seed(1) sets the random number generator (which comes from numpy) to have a "seed" value of 1 (there's no reason it has to be 1). This means that the outputs will be the same every time we run this notebook, which is not only a good practice in science (since an experiment isn't worth much unless it's reproducible), but is very useful for this tutorial, since it ensures you and I will see the same results when we run the code.

If you are running this on your own computer, you can run a cell and move to the next one by pressing shift and enter, or pressing the button to step forward at the top of the page.

In [1]:
from math import sqrt, ceil
import numpy as np
import pandas
import matplotlib.pyplot as plt
from matplotlib.pyplot import bar
from matplotlib import animation, rc, cm

%matplotlib inline
rc('animation', html='html5')
np.random.seed(1)

Next we will load our data and take a few peeks.

The data we will be using today comes from a DNA microarray experiment performed by Joseph L. DeRisi, Vishwanath R. Iyer, and Patrick O. Brown in 1997.

You can find all of the raw data in the "Gene Expression Omnibus" database, which is maintained by NCBI. Our data specifically is series GSE28.

The series consists of yeast that was placed into an incubator and allowed to slowly run out of food. Samples were collected at the start, 9.5 hours later, and then at the 11.5, 13.5, 15.5, 18.5, and 20.5 hour marks, for a total of 7 timepoints.

For each sample, the amount of mRNA present for each of the 6153 yeast genes was measured, which tells us how strongly each and every one of those genes was being expressed.

I've taken the raw data from the website and done all of the processing already. For simplicity's sake I'm not including the process of compiling all of the raw machine output into a final dataset (at least for now; check back for updates!); instead we'll skip right to the set I compiled myself, which is located on the GitHub repo where you (hopefully) downloaded this notebook from.

A note on formatting:

In my compiled data I have taken the ratio of the reading at every timepoint to the reference reading taken at the start of the experiment, and then taken the base 2 logarithm of those ratios.

So a 0 at any timepoint means no difference in expression between that time and the start of the experiment, a -1 indicates half the level of expression, a +1 indicates twice the expression, a +2 indicates four times, and so on.
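
To make that concrete, here's a tiny sketch of the transformation (the readings are made up for illustration):

import numpy as np

reference = 120.0                                  # hypothetical reading at the start
readings  = np.array([120.0, 60.0, 240.0, 480.0])  # hypothetical later readings

np.log2(readings / reference)                      # array([ 0., -1.,  1.,  2.])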

Each row in the dataset starts with a unique index number, from 0 to 6152, then contains the standard "systematic" name for the gene/protein, and the "common" name for the gene (which doesn't always exist, and many entries in this column read "NaN" or "Not a Number", which is used to fill in blanks).

Exploring the data

To begin, we'll load the data. Pandas makes this easy with its "read_csv" function. Note that it doesn't matter that the file doesn't actually end in ".csv"; the function doesn't care, as long as the file it's told to read is formatted like a csv, or "comma-separated values", file.

So since my file contains nothing but values that are separated by commas for each column, and has a new line for every row, there's no issue; we can simply read the file with the function from pandas.

In [2]:
dataset = pandas.read_csv("dataset.txt")

Now that the data is loaded, we can start exploring it, yay!

Now, since there are over 6000 rows I won't display it all at once (though if you are running this in your own notebook, feel free to do that).

Instead, we can look at only the first few points of data by using dataset.head(), a handy feature provided by pandas.

You can see the things I talked about above: the index number, the systematic name of the Open Reading Frame or "ORF", the common name for each gene, and a bunch of numeric values - one value for each of the timepoints.

In [3]:
dataset.head()
Out[3]:
ORF GENE_NAME 0 9.5 11.5 13.5 15.5 18.5 20.5
0 YHR007C ERG11 0.164815 0.026563 0.396236 -0.187368 -0.237240 -1.431412 -1.206855
1 YAL051W OAF1 -0.056694 0.294169 0.394323 0.240178 0.223565 -0.209674 -0.120964
2 YAL054C ACS1 -1.131074 -0.853711 -0.136515 -0.213369 -0.448194 1.666371 3.842899
3 YAL056W GPE2 0.030979 0.544231 0.248595 -0.001895 -0.199895 0.074518 0.538777
4 YAL001C TFC3 -0.431456 0.828072 -0.128579 0.004986 -0.407332 -0.086866 -0.111004

We could also view a random sample of rows with dataset.sample(). Let's take a look at 10 random rows. You'll notice that the index numbers of our sample aren't in order anymore, a byproduct of the random selection.

In [4]:
dataset.sample(10)
Out[4]:
ORF GENE_NAME 0 9.5 11.5 13.5 15.5 18.5 20.5
5091 YFR015C GSY1 -0.653442 -0.571970 0.608941 1.901069 1.923572 3.443486 2.864934
2139 YGR180C RNR4 0.548868 -0.065234 0.604228 -0.049093 0.118624 0.833885 -0.210440
5745 YMR170C ALD2 0.148068 -0.041179 0.234122 0.455728 0.797618 3.731788 2.340016
544 YGL230C NaN -0.822688 -1.618910 -0.106128 -0.329833 -0.674039 0.476801 -0.098501
200 YDL052C SLC1 -0.010659 0.381363 -0.121106 -0.279021 -0.597331 -1.450360 -1.451589
644 YHR025W THR1 0.094112 -0.294995 0.660237 0.025997 -0.388282 -0.933230 -1.087271
2855 YOL075C NaN 0.126298 0.028166 0.485701 0.128324 -0.002184 0.533085 0.781960
6152 YIL172C NaN -0.252871 -0.461282 0.558760 -0.373240 -0.105231 0.675211 0.791520
5606 YKR068C BET3 0.188102 -0.094095 0.135375 -0.020596 -0.035624 0.046611 -0.227153
4298 YNL211C NaN -0.109304 -0.457173 -0.017317 -0.230250 -0.420622 0.168594 -0.269155

Thanks to pandas, it's also easy to look at only the time columns, hiding the names of the genes.

In [5]:
timepoints = ["0", "9.5", "11.5", "13.5", "15.5", "18.5", "20.5"]
dataset[timepoints].head()
Out[5]:
0 9.5 11.5 13.5 15.5 18.5 20.5
0 0.164815 0.026563 0.396236 -0.187368 -0.237240 -1.431412 -1.206855
1 -0.056694 0.294169 0.394323 0.240178 0.223565 -0.209674 -0.120964
2 -1.131074 -0.853711 -0.136515 -0.213369 -0.448194 1.666371 3.842899
3 0.030979 0.544231 0.248595 -0.001895 -0.199895 0.074518 0.538777
4 -0.431456 0.828072 -0.128579 0.004986 -0.407332 -0.086866 -0.111004

Finally, we can get summaries like this one:

In [6]:
dataset.describe()
Out[6]:
0 9.5 11.5 13.5 15.5 18.5 20.5
count 6148.000000 6145.000000 6141.000000 6149.000000 6145.000000 6151.000000 6145.000000
mean -0.094249 -0.339335 0.127342 -0.200178 -0.253921 0.161199 -0.188748
std 0.289747 0.372278 0.347693 0.366759 0.507747 0.901469 0.923154
min -2.397684 -4.865614 -2.218738 -3.827819 -6.309855 -3.351044 -9.364864
25% -0.228340 -0.509487 -0.102662 -0.409451 -0.535623 -0.361172 -0.671828
50% -0.060005 -0.314930 0.113594 -0.208282 -0.267249 0.160066 -0.189960
75% 0.086474 -0.133189 0.353468 0.000895 0.011375 0.681874 0.285283
max 1.917538 1.379943 2.754888 1.930286 2.653681 4.297702 3.958122

It looks like the standard deviation is very small compared to the min and max for each of the timepoints, and the 25th and 75th percentiles are very close to each other. This suggests that the majority of the data is concentrated in one region.

Visualizing the data

It's all well and good to look at things like standard deviations and percentiles, but numbers like those are harder to grasp at a glance than a graph is.

We'll start with a histogram, drawing the 6 timepoints after the start of the experiment.

We'll even make this a function, so we can use it later without re-typing it all.

Read the comments (text after a "#") to follow along.

In [7]:
# The first line is a function definition; we'll name it "histograms"
# Our function will take a dataset, and make histograms out of it
def histograms(dataset):
    
    # Creating an empty figure on which we will plot things
    # the "figsize=(15, 16)" bit just set a custom size for
    # the canvas, to prevent squishing
    fig = plt.figure(figsize=(15,16))
    
    # Now we're going to go through every timepoint  *except* the first one 
    # and also number them. Python has a built-in "enumerate" function that
    # assigns numbers to things, and we'll use a "for" loop to move through 
    # the timepoints and do our work. 
    for i, time in enumerate(timepoints[1:]):
    
        # add_subplot specifies a number of rows and columns that will be drawn,
        # it also needs to be told which subplot is currently being worked on, 
        # which is what the i+1 is for. The +1 is necessary since enumerate starts
        # on 0, but the top-left subplot is numbered 1. 
        fig.add_subplot(3,2,i+1)
        
        # Pretty simple, get the time column from the dataset and plot a histogram
        # The histogram has 60 bins, and covers the range of -3 to 3
        dataset[time].plot.hist(bins=60, range = (-3, 3))
        
        # Title the image
        plt.title("{} hour timepoint".format(time))
        
        # Set a standard y-axis height for easier comparison
        plt.ylim([0,800])
        
# Now we'll call the function
histograms(dataset)

Of course, we can also visualize the distribution with a KDE (kernel density estimate) plot, and the code is almost identical.

In [8]:
def kdes(dataset):

    fig = plt.figure(figsize=(15, 16))
    
    for i, time in enumerate(timepoints[1:]):
        fig.add_subplot(3,2,i+1)
        
        dataset[time].plot.kde()
        plt.title("{} hour timepoint".format(time))
        
        plt.ylim([0,1.4])
        plt.xlim([-3, 3])
        
        
kdes(dataset)

Even in the final timepoints, the vast majority of the genes never change expression very much.

Since we're interested in the genes that are actually affected by starvation, let's remove all the genes whose expression never changes by more than about 40% (0.5 on our log2 scale).

We'll start by getting every row where all the values along the row axis (that is, along axis 1) have an absolute value of less than 0.5.

Note: we could also look for rows where not any of the values has an absolute value of 0.5 or more.

In pseudo code I'm saying:

"Get the part of the whole dataset where ALL of the timepoint-columns in each row (axis=1) have absolute values less that 0.5" and call it the "duds"

In Python that looks like:

In [9]:
duds = dataset[(dataset[timepoints].abs() < 0.5).all(axis=1)]
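
For the record, here's a sketch of the "not any" version mentioned above. One subtlety: rows containing missing values behave differently under the two versions, because a NaN comparison returns False either way.

# Keep rows where no timepoint reaches an absolute value of 0.5
duds_alt = dataset[~(dataset[timepoints].abs() >= 0.5).any(axis=1)]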

We may as well describe it, too:

In [10]:
duds.describe()
Out[10]:
0 9.5 11.5 13.5 15.5 18.5 20.5
count 867.000000 867.000000 867.000000 867.000000 867.000000 867.000000 867.000000
mean -0.023672 -0.215698 0.092564 -0.149481 -0.149086 0.050245 -0.107836
std 0.163472 0.189740 0.205860 0.180489 0.215854 0.272557 0.242434
min -0.486538 -0.499760 -0.499812 -0.499056 -0.498712 -0.494143 -0.499947
25% -0.130129 -0.353815 -0.054641 -0.278146 -0.323819 -0.170038 -0.294133
50% -0.022284 -0.241631 0.090328 -0.154879 -0.176726 0.059820 -0.145969
75% 0.088877 -0.114544 0.254383 -0.031825 -0.003758 0.295878 0.058659
max 0.450676 0.494666 0.498588 0.432919 0.489353 0.499965 0.498671

Wow, 867 genes that barely changed expression at all. Let's drop them from our data and then look at the new histograms.

In [11]:
trimmed_data = dataset.drop(duds.index)
histograms(trimmed_data)

It seems like most of the changes happen in the last 5 hours. Interestingly, the middle parts of the last 2 histograms, while smaller, are not completely empty. Huh.

Maybe a different kind of graph would provide more insight.

Since the data is made up of values and timepoints, it should be pretty easy to plot a line graph, but since pandas likes to plot with the "index" as the x axis, we're going to want to transpose the data, first.

Since plotting 4000+ lines on one graph would be a complete mess, we'll plot a sample of 10 points.

In [12]:
# First we grab a sample of 10 rows, just like we did earlier
sample = trimmed_data.sample(10)

# Then we transpose it
transposed_sample = sample[timepoints].T

# And finally we plot. The semicolon prevents that annoying text string from appearing
transposed_sample.plot.line(legend=False);

Aha! It seems like many of the points jump around a lot, including a bunch that spike early and then drop back to 0. That means either this kind of variation is totally normal for any gene, or there's something interesting going on.

Note: If we wanted to focus on just the extreme genes we could drop more data points, but it's possible that some of these shifts are important parts of the yeast's biology, and if we drop too much data we risk oversimplifying.

That said, it would be nice to somehow categorize these patterns and get a list of proteins whose expression levels over time are similar.
For that we can use:

Unsupervised Machine Learning

Specifically, the K-means clustering algorithm, sometimes called "Lloyd's algorithm", which is a fast and very popular tool for clustering data into categories. It's far from perfect: algorithms like Expectation-Maximization clustering tend to obtain better results, with more room for subtlety, but K-means is extremely fast and can be used in most scenarios with little to no issue.
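
(If you'd like to experiment with the EM approach later, scikit-learn implements it as a Gaussian mixture model; a minimal sketch, assuming X is an array of points:)

from sklearn.mixture import GaussianMixture

# Unlike K-means, EM gives each point a *probability* of belonging
# to each cluster, rather than a hard in-or-out label
gmm = GaussianMixture(n_components=3).fit(X)
soft_labels = gmm.predict_proba(X)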

To see K-means in action, let's whip up some clusters:

In [13]:
# First we'll pick some points to use as the true centers of the clusters
xs, ys = [-9, -4, 9], [9, -6, 4]

# Then we'll pick some sigma values
sigmas = [4, 3, 4]

# And finally create some blank lists
xdata, ydata = [], []

# All that's left is to make a few clusters with a for loop
for i in range(3):
    # We'll get a bunch of x and y coordinates from a normal distribution
    coordsx = list(np.random.normal(xs[i], sigmas[i], (30)))
    coordsy = list(np.random.normal(ys[i], sigmas[i], (30)))
    
    # Saving them for later
    xdata += coordsx
    ydata += coordsy
    
# Finally, we're going to "zip" together the x and y coordinates into a bunch of x,y points
points = [(x, y) for x,y in zip(xdata, ydata)]

# And add those points to a new DataFrame
example = pandas.DataFrame(points, columns= ["x", "y"])
example.head()
Out[13]:
x y
0 -7.851853 0.193370
1 -4.808001 3.351480
2 -3.392786 11.425495
3 -9.116417 9.013564
4 -8.763465 4.753542

Let's take a look at our points on a scatterplot!

In [14]:
fig, ax = plt.subplots()
ax.plot(example["x"],example["y"], "o", color='grey',  markersize = 4);

Looking good!

The next step is to decide how many clusters we want our algorithm to find - the "K" in "K-means".

I'm going to say 3, for what are hopefully obvious reasons.

In [15]:
k = 3

The K-means algorithm is fairly simple; it has just 3 steps:

1. Pick k points to use as centers of clusters
2. Put every point into a cluster with the center it is closest to
3. Move the centers to the middle of the clusters in the above step

And then the algorithm repeats steps 2 and 3 for however many iterations you want.
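
In compact form the whole loop is only a few lines. Here's a sketch, where assign_clusters and update_centers are placeholder names for the roles our p2c and c2p functions will play below:

centers = data.sample(k)[columns].values        # Step 1: k random starting centers
for _ in range(n_iterations):
    labels  = assign_clusters(data, centers)    # Step 2: each point joins its nearest center
    centers = update_centers(data, labels)      # Step 3: centers move to their cluster means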

Let's get started.

Step 1: Pick a random set of points to be used as the centers of the first clusters

In [16]:
# Choosing 3 random rows from our example data
sample = example.sample(k)

# And getting the coordinates of those points out as a list
centerslist = sample[["x", "y"]].values

#Finally, we'll plot the centers as big red stars on the same axis as the rest of the points
ax.plot(centerslist.T[0],centerslist.T[1], '*', color = 'red', markersize = 15);
fig
Out[16]:

Step 2: do a "Points to Clusters" step

Every point is added to whichever cluster's center is closest, and we should make this a new function.

First, though, we'll define a function that tells us the distance between two points, and another that determines which center is closest to a given point.

In [17]:
# The distance between point 1 and point 2
def distance(p1, p2):
    # Is the Euclidean distance between the points!
    return sqrt(sum([(x-i)**2 for x,i in zip(p1, p2)]))

# Point to Cluster takes a single point
def p2c(point):
    point = list(point)
    # Looks at the distance to all the centers
    distances = [distance(point, center) for center in centerslist]
    #And returns the index of the smallest distance
    return distances.index(min(distances))

Now we can apply this p2c function to every row in the frame.

In [18]:
# To save some typing I will shorthand the columns I want to plot
cols = ["x", "y"]

example['Cluster'] = example[cols].apply(p2c, axis=1, raw=True)

# Let's see the numbers of points in each cluster
example.groupby("Cluster").count()
Out[18]:
x y
Cluster
0 11 11
1 53 53
2 26 26

That's not great; each cluster should have exactly 30 points. Let's see the graph at this stage:

In [19]:
# First we'll need some distinct colors for the points and the chosen centers
colors = ['fuchsia', 'orange', 'lime']
centercolors = ['indigo', 'orangered', 'darkgreen']

# We may as well define a new function.
# The scatterplot function takes data and assumes "k" number of clusters
def scatterplot(data, k=k):
    
    # Creates a drawing space and axis
    fig, ax = plt.subplots()

    # Then  goes through a loop for each of k clusters
    for i in range(k):
    
        # getting the rows of the points for each cluster
        cluster = data[(data["Cluster"] == i)]

        # plotting the points
        ax.plot(cluster.x, cluster.y, "o", label="Cluster {}".format(i), color = colors[i], ms=4)

        # and adding the centers again
        ax.plot(centerslist.T[0][i],centerslist.T[1][i], '*', color = centercolors[i], markersize = 15);
    
    return fig, ax

# I'm only assigning fig and ax as variables because of some trickery I will employ later.
fig, ax = scatterplot(example);

Phew, this is messy! The randomly chosen centers happened to lie in a vertical line, and the resulting clusters are nothing like what we want!
Let's hurry up and complete the algorithm.

Step 3: do a "Clusters to Points" step

Move the center to the middle of the new cluster by finding the average of the entire cluster, and making that the new center point.

Sidenote: if the dataset were extremely large, we could save time here by using some kind of Monte Carlo method: take a statistical sample of points and find their average, which can dramatically speed up an already fast algorithm.
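
(In pandas that could be as simple as swapping cluster.mean() for something like the line below; the frac=0.1 is an arbitrary choice:)

# Estimate the cluster's center from a random 10% of its rows
approx_center = cluster.sample(frac=0.1).mean()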

In [20]:
# We input the dataframe of our points       
def c2p(data, columns):
    
    # Chuck out that old centerslist
    centerslist = []
    
    # Go through all k centers
    for i in range(k):
        
        # And add the mean of each coordinate in each column to the centerlist
        centerslist += [data[(data.Cluster == i)][columns].mean().tolist()]
    return np.array(centerslist)
centerslist = c2p(example, cols)

That's the function, defined and applied. Let's add these new centers as pentagons.

In [21]:
for i in range(k):
    ax.plot(centerslist.T[0][i],centerslist.T[1][i], 'p', color = centercolors[i], markersize = 15);
fig
Out[21]:

Those centers look more appropriate. The green center has shifted upwards, and the orange center has moved right. Let's repeat steps 2 and 3 and then plot the graph again.

In [22]:
# First we apply the p2c algorithm to all the rows
example['Cluster'] = example[cols].apply(p2c, axis=1, raw=True)

# Then we get new centers
centerslist = c2p(example, cols)

# And finally, we plot the new graph
fig, ax = scatterplot(example);

It's certainly getting there! We can repeat that process again to get the centers closer to their true values.

In [23]:
example['Cluster'] = example[cols].apply(p2c, axis=1, raw=True)
centerslist = c2p(example, cols)

for i in range(k):
    ax.plot(centerslist.T[0][i],centerslist.T[1][i], 'p', color = centercolors[i], markersize = 15);
scatterplot(example);

We can loop the algorithm 20 more times, and it will improve with each loop.

In [24]:
for i in range(20):
    example['Cluster'] = example[cols].apply(p2c, axis=1, raw=True)
    centerslist = c2p(example, cols)
    
scatterplot(example);

Looking good! Feel free to change all the variables above and watch the process with different centers, numbers of clusters, standard deviations, etc.

Otherwise,

It's time to get back to our main dataset

We can choose any number of clusters we want, so let's choose something like 4 clusters for this data. I also think that typing "trimmed_data" all the time is going to be annoying, so we're going to just call it "data".

In [25]:
k = 4
data = trimmed_data

We're also going to want a good list of colors to use. My favourite from the standard set available in matplotlib is "tab10".

In [26]:
colors = cm.tab10

Since we've got the "p2c" function already, we're ready to do the first 2 steps of the algorithm:

Step 1:

  • Choose K rows,
  • get their coordinates into a list
In [27]:
sample = data.sample(k)
centerslist = sample[timepoints].values

Step 2:

  • Assign the rest of the points into clusters
  • We may as well take a look at counts
In [28]:
data['Cluster'] = data[timepoints].apply(p2c, axis=1, raw=True)
data.groupby("Cluster").count()
Out[28]:
ORF GENE_NAME 0 9.5 11.5 13.5 15.5 18.5 20.5
Cluster
0 1385 790 1380 1377 1373 1381 1377 1383 1377
1 403 235 403 403 403 403 403 403 403
2 2278 1326 2278 2278 2278 2278 2278 2278 2278
3 1220 891 1220 1220 1220 1220 1220 1220 1220

So we've got our 5286 points arranged into 4 groups of various sizes, and it would be awesome to look at them, but unfortunately each row has 7 values, and since 7 is bigger than 2 it's not very easy to display 7-dimensional data on a 2D screen.

This is one of the Curses of Dimensionality. We can't make those nice little 2D scatter plots when working with 7-dimensional data.

Fortunately, we can plot a line graph from 7 points, and that's a starting point.

Now, as I mentioned above, graphing 5000+ lines would make an image that's completely incomprehensible, so how about we only plot the arithmetic mean of each cluster's expression at each timepoint.

And to be extra fancy, we'll add a semi-transparent area that shows one standard deviation.

In [29]:
# I'm going to add a "clusters_to_plot" parameter with a default value of range(k), which might be useful later.
def plotmeans(data, clusters_to_plot = range(k)):
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 2, 1])
    # That list there simply stretches the x axis out, it will make the graph look nicer.
    
    # I'm also tired of graphs with black lines on the top and left, let's remove those.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
    # A nice title, and some axis labels
    plt.title("Clustered Yeast Expression Ratios")
    plt.xlabel("Time since innoculation\n/h")
    plt.ylabel("Expression\nRatio", rotation = 0, labelpad = 25)
    # the "\n" is code for "new line"
    
    # This takes the list of column names and turns them into floats 
    # (aka numbers with decimals) so we can accurately plot time on the x axis
    xdata=[float(i) for i in timepoints]
    
    # We'll also go ahead and set the axis to display those ticks
    ax.set_xticks(xdata)
    ax.set_xlim(0, 20.5)

    # And finally, the meat
    # for every integer from 0 up to the number of clusters
    for i in clusters_to_plot:
        
        # Take just the datapoints whose cluster matches that integer
        cluster = data[(data.Cluster == i)][timepoints]
        
        # Calculate a mean and a standard deviation
        mean = cluster.mean()
        std = cluster.std()
        
        # Plot the time value on the x axis, the mean expression on the y, give it a label, and give it a color
        ax.plot(xdata, mean, label="Cluster " + str(i), color= colors(i))
        
        # Then, for every timepoint fill the space in between one deviation below and above,
        # make it 80% transparent, and give it the same color
        plt.fill_between(xdata, mean-std, mean+std, alpha=0.2, facecolor = colors(i))
    
    # Lastly, slap a legend in the upper left corner
    plt.legend(ncol=1, loc = 'upper left');
    return fig
plotmeans(data);

It seems like cluster 1 (orange) has already picked up a lot of the genes that increase in expression, but the other 3 clusters are all in a jumble, especially at the 15.5 and 18.5 hour timepoints.

We can also try to plot a KDE curve of each of the clusters at each of the timepoints:

In [30]:
# One note: the "clusters_to_plot" parameter is set to None by default.
def manykde(data, timepoints, clusters_to_plot = None):
    fig = plt.figure(figsize=(15, 16))
    ax = []
    
    # Here we are going to see if the user put anything in for the "clusters_to_plot" parameter
    # And if they didn't, set it to range(k)
    if clusters_to_plot is None:
        clusters_to_plot = range(k)
    # We had to do it this way because making "range(k)" the default would have 
    # used the k-value from when the function was defined, and not from when the function was called.
    
    # Moving on,

    # For each timepoint get a number and the name of the timepoint
    for i, time in enumerate(timepoints[1:]):

        # Make a subplot in a 3-row, 2 column chart with the position being that number
        # Add the subplot to a list
        ax += [fig.add_subplot(3,2,i+1)]

        # And then for each cluster
        for j in clusters_to_plot:

            # Get the points in the cluster we want to plot, at this timepoint
            s = data[data['Cluster'] == j][time]

            # Plot a KDE curve
            s.plot.kde(label=j, color=colors(j), ax=ax[i], linewidth = 3)

            # Title, axis label, limits, legend
            plt.title("{} hour timepoint".format(time))
            plt.xlabel("Expression\nRatio")
            plt.ylim([0,2])
            plt.xlim([-4, 4])
            plt.legend()

        # This handy-dandy little function automates the laying-out of the subgraphs, saving us the work
        fig.tight_layout()
    return fig, ax
manykde(data, timepoints);
    

And again, it looks like orange represents high-expression genes, and the rest are an overlapping jumble. One thing we didn't see earlier is that cluster 3, the red cluster, has a very narrow density at the earliest timepoint. It also looks like clusters 0 and 2 are nigh-identical at each timepoint.

Time to run the algorithm a few times, maybe 25 cycles?

Step 3+: Move the centers and repeat the process

In [31]:
for i in range(25):
    centerslist = c2p(data, timepoints)
    data['Cluster'] = data[timepoints].apply(p2c, axis=1, raw=True)

After 25 cycles everything should be clustered and we can try some visualizations to see how we've done.

Just for the sake of variety let's try a plotting technique that's completely different: instead of taking all of the rows and plotting them - which would take longer than I want to wait - we can take a sample of, say, 200 from each cluster and plot the clusters individually:

In [32]:
def manylines(data, samples = 200):
    fig = plt.figure(frameon=False, figsize=(15, 15))

    plt.title("Lineplots of {} semi-transparent samples from each cluster".format(samples), y=1.05, fontsize='x-large')
    plt.box(False)
    plt.axis('off')

    ax=[]
    for i in range(k):
        ax.append(fig.add_subplot(int(ceil(k/2)),2,i+1))

        ax[i].set_ylim([-4,4])
        ax[i].set_title("Cluster {}".format(i))
        plt.box(False)
        plt.axis('off')
        plt.setp(ax[i].get_xticklabels(), visible=False)

        cluster = data[(data.Cluster == i)][timepoints].sample(samples)
        cluster.loc[:, "0":"20.5"].T.plot(legend=None, color = [colors(i)], ax=ax[i], alpha = 0.1, linewidth=2)
        
manylines(data)

It's almost like an art piece! We can clearly see patterns in each cluster:

  • Cluster 0 seems to be even, never changing much at all.

  • Cluster 1 shoots upwards

  • Cluster 2 jumps up and down wildly; it seems to have captured the more erratic points.

  • Cluster 3 drops down

We can of course plot the KDEs and the means again:

In [33]:
manykde(data, timepoints);
In [34]:
plotmeans(data);

Yep, these confirm it: nice, clean separation!

It looks like cluster 1 (orange) and cluster 3 (red) are the genes most affected by starvation.

Thanks to the clusters_to_plot parameter we built in, we can look at KDEs of just those clusters

In [35]:
manykde(data, timepoints, clusters_to_plot = [1, 3]);

Let's not forget that we have the names of all of these genes, too. Let's look at a random sample of some of the named genes from cluster 1, the cluster that's highly expressed during starvation conditions.

In [36]:
data[(data['Cluster'] == 1) & data['GENE_NAME'].notnull()].sample(5)
Out[36]:
ORF GENE_NAME 0 9.5 11.5 13.5 15.5 18.5 20.5 Cluster
1664 YBR256C RIB5 0.240748 -0.190134 0.703130 0.537499 0.242312 1.598754 0.604642 1
5201 YGR142W BTN2 -0.275107 -0.318785 0.205271 0.187266 0.823660 1.922950 2.039855 1
2263 YIL101C XBP1 -0.984318 -0.785495 0.087463 0.425926 -0.436864 2.097611 2.203241 1
5257 YHR016C YSC84 -0.060261 -0.113101 0.061466 0.256500 0.996044 1.781079 0.810355 1
3998 YLR178C TFS1 -0.206719 -0.403846 -0.063984 0.794773 1.329846 3.103137 2.015230 1

XBP1 is a cool name. What's that?

We can look it up on yeastgenome.org, which tells us that XBP1 is

not expressed during log phase of growth, but induced by stress or starvation during mitosis, and late in meiosis; represses 15% of all yeast genes as cells transition to quiescence; important for maintaining G1 arrest and for longevity of quiescent cells;

Hey, there we go! A protein that gets expressed when yeast is starving, and its job is to shut down 15% of all genes and maintain the longevity of cells that are "quiescent" (microbiologist-speak for hibernating).

Our algorithm has proved useful: using just some microarray data, it has helped us sort 6000+ genes into categories, correctly identifying genes that are important to cellular starvation. We've done some real data science, albeit on a dataset that's been around for a while, from one of the most-studied organisms on the entire planet.

That's been the main part of this tutorial, but I encourage you to stick around for a bonus segment where I will make some more sophisticated use of the MatPlotLib library to generate an animation!

Advanced Plotting: Animation

Let's take the full dataset and make a movie out of the clustering process!

We can use the full-sized dataset, and find 6 clusters inside of it.

Set K, get sample, make centers, assign clusters

In [37]:
k = 6
sample = data.sample(k)
centerslist = sample[timepoints].values
data['Cluster'] = data[timepoints].apply(p2c, axis=1, raw=True)

We're going to use matplotlib's "FuncAnimation" tool, and we're going to animate the movement of the KDE lines for the final timepoint.

This is going to be a lot of code, but to be honest most of it is formatting and comments

In [42]:
# This part is nothing new
def clustering_animation(frames, interval = 250):
    
    # We're going to declare a few variables
    labels = ["Cluster {}".format(i) for i in range(k)]
    x = data.groupby("Cluster").size()    
    xdata = [float(t) for t in timepoints]
    
    # Then make the master figure and remove the box and axis from it
    fig = plt.figure(figsize=(10, 15))
    plt.box(False)
    plt.axis('off')
    
    # The title here is long, so we're splitting it onto 2 lines. 
    # Python will treat everything that's inside of the same bracket 
    # as being on the same line.
    # y = 1.05 moves the title up to 105% of its normal height.
    # This will give us room for a legend underneath of the title.
    plt.title("Animation of K-means Clustering on a 7-dimensional Microarray Dataset",
              y=1.05, fontsize='x-large')  
    
    # For the first graph we'll do the means of the clusters
    
    # We'll create an axis, and call it "ax1"
    ax1 = fig.add_subplot(3,1,1)
    
    # Remove those pesky spines and set the limits of the y-axis for consistency
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.set_ylim([-3,3])

    # Add some titles and labels
    ax1.set_title("Mean Values of Yeast Expression Ratios")
    ax1.set_xlabel("Time since innoculation\n  /h")
    ax1.set_ylabel("Expression\nRatio", rotation = 0, labelpad = 25)
    
    # And now we are going to render the lines and add those lines to a list
    # We're doing this so we can directly change the line data later, 
    # without having to redraw everything each time
    lines = []
    for i in range(k):
        cluster = data[(data.Cluster == i)][timepoints]
        mean = cluster.mean()
        
        # lobj means "line object". The trailing comma is important:
        # plot() returns a one-element list, and the comma unpacks it
        lobj, = ax1.plot(xdata, mean, color = colors(i), label=labels[i], lw=4)
        # adding the line object to the list
        lines.append(lobj)
    
    # The eagle-eyed among you will notice that we didn't draw the 
    # standard deviations yet. That's because the "fill_between" 
    # is not easy to change on-the-fly, so we're going to redraw it 
    # every frame instead of just editing the data like we do for the lines
    
    
    # The second plot will be the KDEs
    # Once again we're not bothering to render a first set of lines, 
    # The KDE lines are too much effort to edit on-the-fly
    
    # First is formatting. I absolutely could wrap this up in a function
    # instead of re-writing it, but that might make it harder to tweak and tune,
    # and I want *you* to have the ability to fiddle with these settings yourself
    ax2 = fig.add_subplot(3,1,2)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.set_ylim([0,2])
    ax2.set_xlim([-4, 4])
    ax2.set_title("KDE Plot of Expression Ratios at the 20.5 hour Timepoint")
    ax2.set_xlabel("Expression\nRatio")

    # The third plot will be a bar graph of the number of genes in each cluster
    ax3 = fig.add_subplot(3,1,3)
    ax3.spines['top'].set_visible(False)
    ax3.spines['right'].set_visible(False)
    ax3.set_title("Number of Points in Each Cluster")
    ax3.set_ylabel("Number", rotation = 0, labelpad = 30)
    ax3.set_xlabel("Cluster")
    
    # Here we create a bunch of rectangles which we will add to a list to update later.
    rectangles = list(bar([i for i in range(k)], data.groupby("Cluster").size(), 
                          color = [colors(i) for i in range(k)]))
    
    
    # One final bit of formatting to create and modify a legend 
    leg = fig.legend(lines,                       # We'll use the list of lines as our representative image 
                     [i for i in range(k)],       # And label the lines according to their cluster number
                     title = 'Cluster:',          # The legend will have its own title
                     fontsize='large',            # We may as well make large text
                     ncol=k,                      # The legend has k columns (one row)
                     loc = 'center',              # The legend will be positioned based on its center
                     bbox_to_anchor=(0.57, 0.96), # Here we fiddle with moving the legend around
                     frameon = False)             # Finally we're removing the box around the legend
    
    # Here we do a little more fiddling with the legend's title
    leg._legend_box.align = "left"                # Left align the title text
    leg.get_title().set_position((-55, -19))      # Move the text down to be beside the row of entries
    leg.get_title().set_fontsize('large')         # And make the fontsize consistent with the entries
    
    # Finally we do that tight_layout thing, with an extra request for more "height padding" between subplots
    plt.tight_layout(h_pad=3)
    
    
    # The next piece is the code for updating each frame. 
    # "i" is just the framecount, we won't need it though.
    def animate(i):
        # We have to tell python that "centerslist" should be the list we declared outside of the function,
        # Otherwise it will be confused when we update the variable later
        global centerslist
    
        # We're going to update the rectangles first, because they are easy
        # n is the number of datapoints in each cluster
        n = data.groupby("Cluster").size().values
        
        # we pair up rectangles and n-values, and set each rectangle's height to its n-value
        for rectangle, height in zip(rectangles, n):
            rectangle.set_height(height)

        # These two lines delete the existing standard deviation fills and the KDE plot lines
        ax1.collections = []
        ax2.lines = []
        
        # Now we draw the line graph and KDE plot
        # I'm being a bit sneaky here and piggybacking the kde-updates onto the same loop as the line-updates
        for j, line in enumerate(lines):
            # Grab the mean, standard deviation for the cluster
            cluster = data[(data.Cluster == j)][timepoints]
            mean = cluster.mean()
            std = cluster.std()
            
            # Then simply set the line data to the new data
            line.set_data(xdata, mean)
            
            # And draw that fill-between
            ax1.fill_between(xdata, mean-std, mean+std, facecolor = colors(j), alpha=0.1)
            
            # For the KDE plot we are going to get only the last timepoint
            # In python asking a list for a negative index counts from the tail-end
            s = cluster[timepoints[-1]]
            
            # Finally we plot the KDE curve for this loop
            s.plot.kde(label=labels[j], color=colors(j), ax = ax2, linewidth = 3)
            
        # All that's left is to step through the clustering algorithm once
        centerslist = c2p(data, timepoints)
        data['Cluster'] = data[timepoints].apply(p2c, axis=1, raw=True)
        
        # And to return the updated rectangles and lines   
        return rectangles + lines
    
    
    # It's time to use the "FuncAnimation" function to build our movie
    anim = animation.FuncAnimation(fig,               # "fig" is the fig we just built
                                   animate,           # Animate is the function that updates the data
                                   frames=frames,     # Frames is however many frames we put in the function call above
                                   interval=interval, # Interval is the time to wait between frames (in milliseconds)
                                   blit=True);        # blit will save a bit of time by only drawing things that change
   
    plt.close() # This closes down the figure
    return anim # and then we return the finished animation!

Phew! That was a LOT of text.

Before you get freaked out, remember that most of it was comments, and then most of the actual code was formatting.

Anyways, it's time to wind down and watch a movie!

Note: If you are running this at home, it may take a minute or so to render. If you get an error saying you have no moviewriter, your FFMPEG install is not set up correctly.

Note 2: if for any reason the video does not show up properly for you, you can find an mp4 file on the GitHub.
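
(Alternatively, you can write the movie straight to disk instead of rendering it inline; a sketch, assuming FFMPEG is installed, with a filename of your choosing:)

anim = clustering_animation(60)
anim.save("yeast_clustering.mp4", writer="ffmpeg")  # writes the movie as an mp4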

In [39]:
clustering_animation(60)
Out[39]:

At a glance it looks like creating more clusters hasn't accomplished anything special, but let's take a look at that art-piece plot again.

In [40]:
manylines(data)

Despite the muddling of clusters 0, 1, 3, and 5 in the KDE plot, there are clear patterns to be seen in each cluster:

  • Cluster 0 slowly increases
  • Cluster 1 slowly decreases
  • Cluster 2 rapidly increases, and it's a lot more dispersed
  • Cluster 3 is all over the place
  • Cluster 4 sharply decreases
  • Cluster 5 is tame and barely moves at all.

And that's it (for now), but if you're using this notebook on your own computer feel free to explore some of the genes in each cluster, maybe see why some seem to vary a lot while others don't by looking up different genes from each cluster in the yeast genome database.

Next time:

One of the issues with this algorithm is that there's no sense of subtlety. Every point is either absolutely in or absolutely not in a cluster, but that's not always the best way to categorize things. Next time we'll investigate some ways of making these clusters "soft", as well as some statistical notions of clustering and ways to implement those in code. Stay Tuned!

Also in the pipeline:

It was slightly annoying to have to decide on a K-value to sort with every time, wasn't it? Clustering algorithms like K-means are often called "flat", as opposed to the "hierarchical" clustering algorithms that create tiered trees such as this one: Hierarchy

With these, the data comes out as a tree, and you can look at each level of clustering. Look forward to a future installment about this kind of clustering algorithm.
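
(For the impatient: scipy already implements hierarchical clustering; a minimal sketch, assuming X is an array of points:)

from scipy.cluster.hierarchy import linkage, dendrogram

# Build the tree of merges, then draw it as a dendrogram
Z = linkage(X, method="average")
dendrogram(Z);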

Note that those tend to take much, much longer to run, though, so people usually avoid them if they think simply guessing a K-value will work about as well.