7 ways to label a cluster plot in Python

This tutorial shows you 7 different ways to label a scatter plot with different groups (or clusters) of data points. I made the plots using the Python packages matplotlib and seaborn, but you could reproduce them in most plotting software. These labeling methods are useful for representing the results of clustering algorithms, such as k-means clustering, or whenever your data is divided into groups that tend to cluster together.

Here's a sneak peek of some of the plots:

cluster_subplots.png

You can access the Jupyter notebook I used to create the plots here. I also embedded the code below.

First, we need to import a few libraries and define some basic formatting:

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

#set font size of labels on matplotlib plots
plt.rc('font', size=16)

#set style of plots
sns.set_style('white')

#define a custom palette
customPalette = ['#630C3A', '#39C8C6', '#D3500C', '#FFB139']
sns.set_palette(customPalette)
sns.palplot(customPalette)

CREATE LABELED GROUPS OF DATA

Next, we need to generate some data to plot. I defined four groups (A, B, C, and D) and specified their center points. For each group, I sampled n points (an n×2 array of x and y coordinates) from a Gaussian distribution centered on the group's center, with a standard deviation of 0.5 in each dimension.

To make these plots, each datapoint needs to be assigned a label. If your data isn't labeled, you can use a clustering algorithm to create artificial groups.

#number of points per group
n = 50

#define group labels and their centers
groups = {'A': (2,2),
          'B': (3,4),
          'C': (4,4),
          'D': (4,1)}

#create labeled x and y data
data = pd.DataFrame(index=range(n*len(groups)), columns=['x','y','label'])
for i, group in enumerate(groups.keys()):
    #randomly select n datapoints from a gaussian distribution
    data.loc[i*n:((i+1)*n)-1,['x','y']] = np.random.normal(groups[group], 
                                                           [0.5,0.5], 
                                                           [n,2])
    #add group labels
    data.loc[i*n:((i+1)*n)-1,['label']] = group

#the empty DataFrame starts with object columns; convert x and y to floats
#so that seaborn and pandas treat them as numbers
data[['x','y']] = data[['x','y']].astype(float)

data.head()
example_data.png
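If your data isn't labeled, the clustering step mentioned above could look something like the sketch below: a bare-bones k-means (Lloyd's algorithm) written with NumPy only. The function name `kmeans_labels` and its parameters are my own illustrative choices, not part of any library; in practice a tested implementation such as scikit-learn's `KMeans` is a safer bet.

```python
import numpy as np

def kmeans_labels(points, k, n_iter=100, seed=0):
    """Assign each row of `points` (shape [m, 2]) to one of k clusters."""
    rng = np.random.default_rng(seed)
    # start from k distinct data points chosen at random
    centers = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(n_iter):
        # distance from every point to every center, shape [m, k]
        dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # move each center to the mean of its assigned points
        # (keep the old center if a cluster ends up empty)
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels, centers
```

With the DataFrame above you could write `data['label'], _ = kmeans_labels(data[['x', 'y']].to_numpy(dtype=float), k=4)`. The resulting labels are cluster numbers (0 to 3) rather than the letters A to D, so the plotting loops below would need to iterate over those numbers instead of `groups.keys()`.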

STYLE 1: STANDARD LEGEND

Seaborn makes it incredibly easy to generate a nice-looking labeled scatter plot. This style works well if your data points are labeled but don't really form clusters, or if your labels are long.

#plot data with seaborn
facet = sns.lmplot(data=data, x='x', y='y', hue='label', 
                   fit_reg=False, legend=True, legend_out=True)
cluster_plot_standard_legend.png

STYLE 2: COLOR-CODED LEGEND

This is a slightly fancier version of style 1 where the text labels in the legend are also color-coded. I like using this option when I have longer labels. When I'm going for a minimal look, I'll drop the colored bullet points in the legend and only keep the colored text.

#plot data with seaborn (don't add a legend yet)
facet = sns.lmplot(data=data, x='x', y='y', hue='label', 
                   fit_reg=False, legend=False)

#add a legend
leg = facet.ax.legend(bbox_to_anchor=[1, 0.75],
                         title="label", fancybox=True)
#change colors of labels
for i, text in enumerate(leg.get_texts()):
    plt.setp(text, color = customPalette[i])
    
cluster_plot_colored_legend.png
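The minimal variant mentioned above (colored text, no bullets) can be sketched in plain matplotlib as follows. It uses made-up placeholder data rather than the `data` frame, and the attribute that holds the legend markers depends on your matplotlib version (`legend_handles` in recent releases, `legendHandles` in older ones), so the sketch checks for both.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

customPalette = ['#630C3A', '#39C8C6', '#D3500C', '#FFB139']
rng = np.random.default_rng(0)

fig, ax = plt.subplots(figsize=(5, 5))
for i, label in enumerate(['A', 'B', 'C', 'D']):
    pts = rng.normal(i, 0.5, size=(20, 2))  # placeholder data
    ax.scatter(pts[:, 0], pts[:, 1], color=customPalette[i], label=label)

# handlelength=0 and handletextpad=0 remove the space reserved for the markers
leg = ax.legend(bbox_to_anchor=[1, 0.75], title="label",
                handlelength=0, handletextpad=0)

# newer matplotlib exposes the marker artists as legend_handles,
# older versions as legendHandles
handles = getattr(leg, "legend_handles", None) or leg.legendHandles
for handle, text, color in zip(handles, leg.get_texts(), customPalette):
    handle.set_visible(False)  # hide the colored bullet
    text.set_color(color)      # keep the colored text
```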

STYLE 3: COLOR-CODED TITLE

This option can work really well in some contexts, but poorly in others. It probably isn't a good option if you have a lot of group labels or if the group labels are very long. However, if you have only 2 or 3 labels, it can make for a clean and stylish option. I would use this type of labeling in a presentation or in a blog post, but I probably wouldn't use it in more formal contexts like an academic paper.

#plot data with seaborn
facet = sns.lmplot(data=data, x='x', y='y', hue='label', 
                   fit_reg=False, legend=False)

#define padding -- higher numbers will move title rightward
pad = 4.5

#define separation between cluster labels
sep = 0.3

#define y position of title
y = 5.6

#add beginning of title in black
facet.ax.text(pad, y, 'Distributions of points in clusters:', 
              ha='right', va='bottom', color='black')

#add color-coded cluster labels
for i, label in enumerate(groups.keys()):
    text = facet.ax.text(pad+((i+1)*sep), y, label, 
                         ha='right', va='bottom',
                         color=customPalette[i])
cluster_plot_colored_title.png

STYLE 4: LABELS NEXT TO CLUSTERS

This is my favorite style and the labeling scheme I use most often. I generally like to place labels next to the data instead of in a legend. The only drawback of this labeling scheme is that you need to hard-code where you want the labels to be positioned.

#define labels and where they should go
labels = {'A': (1.25,1),
          'B': (2.25,4.5),
          'C': (4.75,3.5),
          'D': (4.75,1.5)}

#create a new figure
plt.figure(figsize=(5,5))

#loop through labels and plot each cluster
for i, label in enumerate(groups.keys()):

    #add data points 
    plt.scatter(x=data.loc[data['label']==label, 'x'], 
                y=data.loc[data['label']==label,'y'], 
                color=customPalette[i], 
                alpha=0.7)
    
    #add label
    plt.annotate(label, 
                 labels[label],
                 horizontalalignment='center',
                 verticalalignment='center',
                 size=20, weight='bold',
                 color=customPalette[i]) 
cluster_plot_adjacent_labels.png
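One way to soften that drawback is to compute label positions from the data. The sketch below is a heuristic of my own, not something the plots above use: it places each label a fixed distance beyond its cluster mean, pushed away from the centroid of all points. The helper name `label_positions` and its `offset` parameter are hypothetical, and overlapping clusters may still need manual tweaks.

```python
import numpy as np

def label_positions(points, labels, offset=1.0):
    """Hypothetical helper: put each label `offset` units beyond its cluster
    mean, along the direction pointing away from the centroid of all points."""
    points = np.asarray(points, dtype=float)
    labels = np.asarray(labels)
    overall = points.mean(axis=0)
    positions = {}
    for lab in np.unique(labels):
        center = points[labels == lab].mean(axis=0)
        direction = center - overall
        norm = np.linalg.norm(direction)
        # a cluster sitting right on the overall centroid has no "outward"
        # direction, so just nudge its label straight up
        unit = direction / norm if norm > 0 else np.array([0.0, 1.0])
        positions[lab] = tuple(center + offset * unit)
    return positions
```

With the tutorial's data you could replace the hard-coded dictionary with `labels = label_positions(data[['x', 'y']], data['label'], offset=1.0)` and keep the plotting loop unchanged.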

STYLE 5: LABELS CENTERED ON CLUSTER MEANS

This style is advantageous if you care more about where the cluster means are than the locations of the individual points. I made the points more transparent to improve the visibility of the labels.

#create a new figure
plt.figure(figsize=(5,5))

#loop through labels and plot each cluster
for i, label in enumerate(groups.keys()):

    #add data points 
    plt.scatter(x=data.loc[data['label']==label, 'x'], 
                y=data.loc[data['label']==label,'y'], 
                color=customPalette[i], 
                alpha=0.20)
    
    #add label
    plt.annotate(label, 
                 data.loc[data['label']==label,['x','y']].mean(),
                 horizontalalignment='center',
                 verticalalignment='center',
                 size=20, weight='bold',
                 color=customPalette[i]) 
cluster_plot_labeled_means1.png

STYLE 6: LABELS CENTERED ON CLUSTER MEANS (2)

This style is similar to style 5, but relies on a different way to improve label visibility. Here, the label backgrounds are color-coded and the text is white.

#create a new figure
plt.figure(figsize=(5,5))

#loop through labels and plot each cluster
for i, label in enumerate(groups.keys()):

    #add data points 
    plt.scatter(x=data.loc[data['label']==label, 'x'], 
                y=data.loc[data['label']==label,'y'], 
                color=customPalette[i], 
                alpha=1)
    
    #add label
    plt.annotate(label, 
                 data.loc[data['label']==label,['x','y']].mean(),
                 horizontalalignment='center',
                 verticalalignment='center',
                 size=20, weight='bold',
                 color='white',
                 backgroundcolor=customPalette[i]) 
cluster_plot_labeled_means2.png

STYLE 7: TEXT MARKERS

This style is a little bit odd, but it can be effective in some situations. This type of labeling scheme may be useful when there are few data points and the labels are very short.

#create a new figure and set the x and y limits
fig, axes = plt.subplots(figsize=(5,5))
axes.set_xlim(0.5,5.5)
axes.set_ylim(-0.5,5.5)

#loop through labels and plot each cluster
for i, label in enumerate(groups.keys()):

    #loop through data points and plot each point 
    for l, row in data.loc[data['label']==label,:].iterrows():
    
        #add the data point as text
        plt.annotate(row['label'], 
                     (row['x'], row['y']),
                     horizontalalignment='center',
                     verticalalignment='center',
                     size=11,
                     color=customPalette[i]) 
cluster_plot_text_markers.png

Detecting ‘bursts’ in time series data with Kleinberg’s burst detection algorithm

A while ago, I was watching an online course about data visualization and one of the analyses that stuck out to me was called burst detection. Burst detection is a way of identifying periods of time in which some event is unusually popular. In other words, you can use it to identify fads, or “bursts,” of events over time.

I realized that I could use burst detection to answer a long-standing curiosity of mine: what does a timeline of fMRI trends look like? What kinds of fads have popped up in the short history of fMRI, and what topics are currently popular? I have an intuitive sense of which topics have been popular over the years, but I like the idea of using burst detection to identify trends in the fMRI literature in a data-driven way.

The video that introduced me to the idea of burst detection used software that I didn’t have access to, but I found the paper the analysis is based on: “Bursty and Hierarchical Structure in Streams” by Kleinberg (2002). I implemented the burst detection algorithm in Python, and the package is available on PyPI and GitHub. My functions implement the algorithms described in the second half of the paper, which detect bursts in discrete batches of events. There are already packages in Python and R that implement the algorithms in the first half of the paper, which detect bursts in continuous streams of events.

In this blog post, I will describe the rationale behind burst detection and describe how to implement it. In subsequent blog posts, I will apply the algorithm to real data. I’ve already found “bursts” in the fMRI literature and I’d like to also detect bursts in my Googling history and in news archives.

RATIONALE OF BURST DETECTION

Kleinberg’s burst detection algorithm identifies time periods in which a target event is uncharacteristically frequent, or “bursty.”  You can use burst detection to detect bursts in a continuous stream of events (like receiving emails) or in discrete batches of events (like poster titles submitted to an annual conference). I focused on detecting bursts in discrete batches of events, since scientific articles are often published in batches in each journal issue.

Here’s the basic idea: a set of events, consisting of both target and non-target events, is observed at each time point t. If we use the poster title example, target events may consist of poster titles that include the word connectivity and non-target events may consist of all other poster titles (that is, all the poster titles that do not include the word connectivity). The total number of events at each time point is denoted by d and the number of target events is denoted by r. The proportion p of target events at each time point is therefore p = r/d.

burst-detection-targets.png

Burst detection assumes that there are multiple states (or modes) that correspond to different probabilities of target events. Some states have high target probabilities, some states have very low target probabilities, and others have moderate target probabilities. If we assume that there are only two possible states, then we can think of the state with the lower probability as the baseline state and the state with the higher probability as the “bursty” state. The baseline probability is equal to the overall proportion of target events:

baseline_prob_eq.png

where R is the total number of target events, summed over all time points, and D is the total number of events, summed over all time points.

The bursty state probability is equal to the baseline probability multiplied by some constant s. You choose the value of s. If s is large, the proportion of target events needs to be high for the system to enter a bursty state.

bursty_prob_eq.png
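In code, the quantities defined so far take only a few lines. This is a sketch assuming array-like inputs `r` and `d` holding the target and total counts at each time point; the function name is hypothetical and not taken from the burst_detection package.

```python
import numpy as np

def state_probabilities(r, d, s):
    """Observed proportions plus baseline (p0) and bursty (p1) probabilities."""
    r = np.asarray(r, dtype=float)
    d = np.asarray(d, dtype=float)
    p = r / d                 # observed proportion at each time point
    p0 = r.sum() / d.sum()    # baseline: R / D
    p1 = s * p0               # bursty: baseline scaled by s
    # p1 is a probability, so s must be small enough to keep it below 1
    if p1 >= 1:
        raise ValueError("s * p0 must stay below 1")
    return p, p0, p1
```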

When you are in a given state, you expect that, on average, the target events will occur with the probability associated with that state. Sometimes the proportion of target events will be higher than expected and sometimes it will be lower than expected due to random noise. The purpose of burst detection is to predict what state the system is in based on the sequence of observed proportions. In other words, given the observed proportions of target events in each batch of events, the burst detection algorithm will determine when the system was likely in a baseline state and when it was likely in a bursty state.

Determining which state the system is in at any given time depends on two things:

1. The goodness of fit between the observed proportion and the expected probability of each state. The closer the observed proportion is to the expected probability of a state, the more likely the system is in that state. Goodness of fit is denoted by sigma, which is defined as:

sigma_eq.png

where i corresponds to the state (in a two-state system, i=0 corresponds to the baseline state and i=1 corresponds to the bursty state).

2. The difficulty of transitioning from the previous state to the next state. There’s a cost associated with entering a higher state, but no cost associated with staying in the same state or returning to a lower state. The transition cost, denoted by tau, therefore equals zero when transitioning to a lower state or staying in the same state. When entering a higher state, the transition cost is defined as:

tau_eq.png

where n is the number of time points and gamma is the difficulty of transitioning into higher states. You can choose the value of gamma. Higher values make it harder to transition into a more bursty state.
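The two cost functions can be written out as below. The equations above appear only as images, so this sketch assumes they match Kleinberg's (2002) discrete formulation: the fit cost is the negative log of the binomial probability of observing r_t target events out of d_t, σ(i, r_t, d_t) = −ln[ C(d_t, r_t) · p_i^(r_t) · (1 − p_i)^(d_t − r_t) ], and moving from state i up to state j costs τ(i, j) = (j − i) · γ · ln n. The function names are illustrative.

```python
import math

def fit_cost(r_t, d_t, p_i):
    """sigma: negative log-likelihood of r_t target events out of d_t under
    a binomial model with target probability p_i."""
    # log of the binomial coefficient C(d_t, r_t); lgamma also accepts
    # non-integer counts, e.g. after smoothing
    log_binom = math.lgamma(d_t + 1) - math.lgamma(r_t + 1) - math.lgamma(d_t - r_t + 1)
    return -(log_binom + r_t * math.log(p_i) + (d_t - r_t) * math.log(1 - p_i))

def transition_cost(i, j, gamma, n):
    """tau: moving from state i up to state j costs (j - i) * gamma * ln(n);
    staying in the same state or moving down is free."""
    return (j - i) * gamma * math.log(n) if j > i else 0.0
```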

The cost of each step is the sum of the two functions above (the fit cost plus the transition cost), and the total cost of a state sequence is the sum of these step costs over all time points. With the cost function in hand, we can find the optimal state sequence, q. The optimal state sequence is the sequence of states that minimizes the total cost or, in other words, the sequence that best explains the observed proportions. We find q with the Viterbi algorithm. A simplified, greedy version of the idea goes like this: first, we calculate the cost of being in each state at t=1 and choose the state with the minimum cost; then we calculate the cost of transitioning from our current state at t=1 to each possible state at t=2, and again we choose the state with the minimum cost. We repeat these steps for all time points. The full Viterbi algorithm is a bit more careful: at each time point it keeps the lowest cumulative cost of reaching every state, not just the single best one, and then traces back the cheapest path, which guarantees the minimum total cost.
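A sketch of the full dynamic-programming (Viterbi) search for a k-state model is below. It assumes the binomial fit cost and the (j − i) · γ · ln n transition cost from Kleinberg (2002), repeated here so the block runs on its own, that the system starts in the baseline state, and that state i has target probability p0 · s^i. The function name and defaults are hypothetical; this is not the burst_detection package's code.

```python
import math
import numpy as np

def fit_cost(r_t, d_t, p_i):
    # sigma: negative log binomial likelihood of r_t out of d_t
    log_binom = math.lgamma(d_t + 1) - math.lgamma(r_t + 1) - math.lgamma(d_t - r_t + 1)
    return -(log_binom + r_t * math.log(p_i) + (d_t - r_t) * math.log(1 - p_i))

def viterbi_states(r, d, s=2.0, gamma=1.0, k=2):
    """Minimum-cost state sequence for a k-state burst model."""
    r = np.asarray(r, dtype=float)
    d = np.asarray(d, dtype=float)
    n = len(r)
    p0 = r.sum() / d.sum()
    probs = [p0 * s**i for i in range(k)]  # state i has probability p0 * s^i
    if probs[-1] >= 1:
        raise ValueError("p0 * s^(k-1) must stay below 1")
    # tau[i, j]: cost of moving from state i to state j
    tau = np.array([[(j - i) * gamma * math.log(n) if j > i else 0.0
                     for j in range(k)] for i in range(k)])
    cost = np.zeros((n, k))             # best cumulative cost ending in each state
    back = np.zeros((n, k), dtype=int)  # predecessor state on that best path
    for j in range(k):
        # the system is assumed to start in the baseline state (state 0)
        cost[0, j] = tau[0, j] + fit_cost(r[0], d[0], probs[j])
    for t in range(1, n):
        for j in range(k):
            candidates = cost[t - 1] + tau[:, j]
            back[t, j] = np.argmin(candidates)
            cost[t, j] = candidates[back[t, j]] + fit_cost(r[t], d[t], probs[j])
    # trace the cheapest path backwards from the final time point
    q = np.zeros(n, dtype=int)
    q[-1] = np.argmin(cost[-1])
    for t in range(n - 1, 0, -1):
        q[t - 1] = back[t, q[t]]
    return q
```

For a toy series with 10 baseline points, 10 points at six times the baseline count, and 10 more baseline points, `viterbi_states([5]*10 + [30]*10 + [5]*10, [100]*30)` marks only the middle ten points as bursty. Keeping the cheapest path into every state, rather than committing to one state per step, is what lets the search weigh a one-time entry cost against the savings over the whole burst.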

The state sequence tells you when the system was in a heightened, or bursty, state. We can repeat these steps for different target events (for example, different words in poster titles) to build a timeline of what events were popular over time.

The strength, or weight, of a burst (that begins at time point t1 and ends at time point t2) can be estimated with the following function:

weight_eq.png

This equation simply tells us how much the fit cost is reduced when we are in a bursty state vs. the baseline state during the burst period. The more the fit cost is reduced, the stronger the burst and the greater the weight.
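Because the binomial coefficient appears in both fit costs, it cancels in the weight, which leaves a simple sum. The sketch below assumes the weight is the sum over t from t1 to t2 of [σ(0, r_t, d_t) − σ(1, r_t, d_t)], consistent with the description above; the function name is hypothetical.

```python
import math

def burst_weight(r, d, p0, p1, t1, t2):
    """Weight of a burst from t1 to t2 (inclusive): the total drop in fit cost
    when the bursty probability p1 is used instead of the baseline p0.
    The binomial coefficient is the same in both sigma terms and cancels."""
    total = 0.0
    for t in range(t1, t2 + 1):
        total += r[t] * math.log(p1 / p0) + (d[t] - r[t]) * math.log((1 - p1) / (1 - p0))
    return total
```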

IMPLEMENTATION WITH SIMULATED DATA

I implemented the burst detection algorithm in Python and created a time series with artificial bursts to test the code. The time series consisted of 1000 time points and bursts were added from t=200 to t=399 and t=700 to t=799. Here’s what the raw time course looked like:

raw_timeseries.jpg

With s set to 2 and gamma set to 1, the algorithm identified one burst from t=701 to t=800 and 32 small bursts between t=200 and t=395. What does this tell us? The burst detection algorithm can easily identify periods in which the proportion of target events is much higher than usual, but it has a harder time identifying weaker bursts, especially in the presence of noise. I repeated the analysis using different values for s and gamma to get a sense of how these values affect burst detection. Here, the bursts from each analysis (represented with blue bars) are plotted on the same timeline:

Screen Shot 2016-11-05 at 4.46.29 PM.png


You can think of s as the distance between states. When s is small, the difference between the states’ expected probabilities is also small. When we increase s while holding gamma constant (as shown in the first four timelines), we get shorter bursts. Essentially, we’re breaking up larger bursts into smaller bursts since we’re increasing the threshold that the observed proportions need to meet in order to be considered in a burst. Since the time course is so noisy, some timepoints in the artificial burst periods do not meet that threshold and fewer and fewer time points meet the threshold as s increases.

Gamma determines how difficult it is to enter a higher state. Since there is no cost associated with staying in the same state or returning to a lower state, changing gamma should only affect the beginnings of bursts and not their endings. You can see that as gamma increases, we get fewer and shorter bursts since, again, we are making it more difficult to enter a bursty state. It’s not obvious from the timeline, but if you look at the start and end points of the burst that survived all of the gamma settings, you find that the burst ends at t=281 regardless of gamma. However, it begins at t=274 when gamma is 0.5, at t=279 when gamma is 1, and at t=280 when gamma is 2 or 3.

As these plots illustrate, the burst detection algorithm is highly sensitive to noise. To reduce the effects of noise, we can temporally smooth the time course. Here’s what the bursts look like when we use a smoothing window with a width of 5 time points:

Screen Shot 2016-11-05 at 4.46.21 PM.png

We get fewer, longer bursts since the proportions at each time point are less noisy. For this data, s=1.5 and gamma=1 identified both bursts.
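The post doesn't spell out how the smoothing was done, so the following is only one plausible version: a centered moving average applied separately to the target counts r and the total counts d before running burst detection. The helper name is hypothetical, and an odd window keeps the average centered.

```python
import numpy as np

def smooth_counts(r, d, window=5):
    """Centered moving average of target (r) and total (d) counts.
    Returns shorter arrays: window - 1 points are lost, half at each end."""
    kernel = np.ones(window) / window
    r_s = np.convolve(np.asarray(r, dtype=float), kernel, mode="valid")
    d_s = np.convolve(np.asarray(d, dtype=float), kernel, mode="valid")
    return r_s, d_s
```

Smoothing the counts rather than the proportions keeps time points with many events weighted more heavily than time points with few.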

Hopefully it’s obvious that the results of burst detection depend heavily on the parameters you choose. It may not be the best method to use if you care about the specific start and end points of bursts (or the number of bursts), but it’s a useful method if you’re interested in general trends in a large dataset over time.

I really recommend reading Kleinberg’s paper for a more detailed (and sophisticated) explanation of burst detection. If you’re interested in seeing the code I wrote to generate the figures in this post, you can check out my IPython notebook. The burst_detection package is the first package I’ve published, so please let me know if you run into any problems or if you notice any errors.

As I mentioned at the beginning of the post, I already applied the burst detection algorithm to the fMRI literature to identify trends in fMRI over the last 20 years. I’ll try to post that analysis soon. I have a few more project ideas after that, including finding trends in my Google search history, finding trends in rap lyrics, and finding trends in news articles. Let me know if you know of any interesting datasets that are amenable to burst detection or if you end up using the burst_detection package yourself.

Classy navy and coral poster template

I made this poster for the Cognitive Neuroscience Society conference in 2015. This poster summarizes a follow-up analysis of my first-year fMRI project. We estimated the certainty and expectedness associated with the stimuli used in the experiment and investigated which brain areas tracked hypothesis certainty and evidence expectancy over time. 

As with most of my posters, I started with a color scheme that I pulled from Pinterest. The color scheme I chose for this poster consisted of navy, coral, slate gray, and tan (which I ended up dropping):

color_scheme.png

For this poster, I wanted to go for a cleaner, more refined look. I also wanted to use a different section format than what I usually use. Sometimes I play around with a few possible designs before deciding on a format and style. I found the mini-poster designs that I made before making this poster. Here are a few of them:

You can see that they're all pretty similar, but they vary a little bit.

Anyway, you can click the buttons below to download the template in Powerpoint or Keynote:

I'd love to see any posters you create with this template! Feel free to post them in the comments below.

How to assign different colors to different conditions in Excel graphs

This post was originally published on May 12, 2014.

Excel is a great program to use if you don’t need complicated graphs and you’re happy with its default formatting (which you shouldn’t be!). However, if you need more advanced formatting, making graphs in Excel can become complicated. It’s not impossible, though: there are ways to manipulate Excel to create beautiful, custom charts.

In this tutorial, I describe how to conditionally format graphs in Excel. Conditional formatting is useful when you have two or more conditions that you want to format differently. For example, you might have data from a patient group and a control group that you want to display in different colors, sizes, or shapes. Excel does let you format data points individually, but applying the same format to every data point in each condition can quickly become labor-intensive. Here's an example where two groups of participants are labeled with different colors:

formatted_scatter_plot.png

STEP 1: ORGANIZE YOUR DATA

To begin, organize all your data in columns. The first column should contain dummy codes for your different conditions. For example, if you have three conditions, you’ll use the numbers 1, 2, and 3 to represent them. The remaining columns should contain the variables you plan to plot; you should have two columns if you want to make a scatter chart and one column if you are making a bar chart or something similar. In this example, I have two groups: a control group, which I dummy coded with 1, and a patient group, which I dummy coded with 2. Each subject has a response time that is listed in column B:

organized_data.png

 

STEP 2: SEPARATE DATA ACCORDING TO CONDITION TYPE

This step is the meat of the tutorial. First, create a new column for each combination of condition and variable (if you have 2 conditions and 1 variable, you’ll need 2 new columns; if you have 3 conditions and 2 variables, you’ll need 6 new columns). In each column, type an IF statement with the following parameters:

=IF(condition = dummy code, variable, NA() )

where "condition" is the cell containing your dummy code, "dummy code" is the actual number of your dummy coding (this number will change in each column), and "variable" is the cell containing your data. So in our example, we’ll use the following IF statements:

In the first new column:    =IF(A6=1,B6,NA())

In the second new column:    =IF(A6=2,B6,NA())

Finally, copy the formulas down to the rest of your data by highlighting the cells containing your IF statements and dragging the blue box in the lower right-hand corner (the fill handle) all the way down to the end of your data.

separated_data.png
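If you prepare your data in Python instead, the same split (one column per condition, with a missing value wherever the condition doesn't match) can be sketched with pandas. The column names and response times here are made up for illustration; `NaN` plays the role of Excel's `#N/A`.

```python
import numpy as np
import pandas as pd

# Hypothetical example data: dummy-coded condition and a response time
df = pd.DataFrame({"condition": [1, 2, 1, 2],
                   "rt": [510, 640, 495, 700]})

# One new column per condition: the value where the condition matches,
# NaN everywhere else (the equivalent of =IF(A6=1,B6,NA()))
for code in sorted(df["condition"].unique()):
    df[f"rt_{code}"] = df["rt"].where(df["condition"] == code, np.nan)
```

Plotting libraries generally skip `NaN` values, so each new column can then be drawn as its own series with its own color.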

STEP 3: CREATE A GRAPH

At this point, you can create whatever type of graph you need — this conditional formatting technique can be used to create a wide variety of graphs. In this example, I created a simple bar chart to visualize subjects’ response times according to condition type.

To create a bar chart, click on the charts tab in the Excel ribbon. Highlight the data in the new columns you just created (the columns with the #N/A values), click on the Bar icon, and select clustered bar chart. Now when you click on a datapoint in the graph, it will highlight all of the other datapoints that belong to the same condition. You can format the datapoints in each condition separately. 

unformatted_bar_graph.png

To make the spacing of the bars uniform, right click on the data series and select format data series. Change the overlap to 100% and the gap width to 100%. To finish formatting, I removed the gridlines, added a title and an x-axis label, changed the typeface of the axes and labels to Avenir, changed the color of the bars to a neutral gray for the controls and a teal for the patients, and removed all shadows and other special effects. I also removed the y-axis labels since the subject number does not matter in this case. Here’s the finished graph:

formatted_graph.png

You can use the same conditional formatting technique to create scatter plots too. Here's an example of how to set up the data:

scatterplot_data.png

And here's what the formatted scatter plot looks like:

formatted_scatter_plot.png