@ken333135
Created May 8, 2019 07:57
Wrapper function to create a Sankey diagram from a DataFrame
import pandas as pd

def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # one color per categorical column: this palette supports at most 5 cat_cols
    colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']
    if len(cat_cols) < 2:
        raise ValueError('cat_cols must contain at least two columns')
    if len(cat_cols) > len(colorPalette):
        raise ValueError('more cat_cols than colors in colorPalette')
    labelList = []
    colorNumList = []
    for catCol in cat_cols:
        # new labels from this column, in order of first appearance; labels already
        # seen in an earlier column are skipped, so labelList has no duplicates
        # and stays aligned with colorList
        labelListTemp = [x for x in dict.fromkeys(df[catCol]) if x not in labelList]
        colorNumList.append(len(labelListTemp))
        labelList = labelList + labelListTemp
    # define colors based on number of levels (one palette entry per column)
    colorList = []
    for idx, colorNum in enumerate(colorNumList):
        colorList = colorList + [colorPalette[idx]]*colorNum
    # transform df into source-target pairs, one block per pair of adjacent columns
    pairDfs = []
    for i in range(len(cat_cols)-1):
        tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]].copy()
        tempDf.columns = ['source','target','count']
        pairDfs.append(tempDf)
    sourceTargetDf = pd.concat(pairDfs)
    sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
    # add integer node indices for each source-target pair
    labelIndex = {label: idx for idx, label in enumerate(labelList)}
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].map(labelIndex)
    sourceTargetDf['targetID'] = sourceTargetDf['target'].map(labelIndex)
    # creating the sankey diagram
    data = dict(
        type='sankey',
        node = dict(
            pad = 15,
            thickness = 20,
            line = dict(
                color = "black",
                width = 0.5
            ),
            label = labelList,
            color = colorList
        ),
        link = dict(
            source = sourceTargetDf['sourceID'],
            target = sourceTargetDf['targetID'],
            value = sourceTargetDf['count']
        )
    )
    layout = dict(
        title = title,
        font = dict(
            size = 10
        )
    )
    fig = dict(data=[data], layout=layout)
    return fig
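The core of genSankey is turning each pair of adjacent categorical columns into (source, target, value) links and mapping every label to an integer node index, which is the shape the Plotly `sankey` trace's `link` expects. Below is a minimal pure-Python sketch of that step, assuming the data is a list of dicts rather than a DataFrame; `sankey_links` and the sample rows are hypothetical and not part of the gist.

```python
from collections import defaultdict


def sankey_links(rows, cat_cols, value_col):
    """Build Sankey node labels and aggregated links from a list of dict rows."""
    labels = []  # node labels, column by column, in first-seen order
    index = {}   # label -> integer node index
    for col in cat_cols:
        for row in rows:
            label = row[col]
            if label not in index:  # a label shared by two columns becomes one node
                index[label] = len(labels)
                labels.append(label)

    # sum the value column for every (source, target) pair of adjacent columns
    totals = defaultdict(int)
    for src_col, tgt_col in zip(cat_cols, cat_cols[1:]):
        for row in rows:
            totals[(index[row[src_col]], index[row[tgt_col]])] += row[value_col]

    # dicts keep insertion order, so keys and values line up
    source = [s for s, _ in totals]
    target = [t for _, t in totals]
    value = list(totals.values())
    return labels, source, target, value


rows = [
    {"gender": "F", "answer": "Yes", "count": 3},
    {"gender": "M", "answer": "Yes", "count": 2},
    {"gender": "F", "answer": "No", "count": 1},
    {"gender": "F", "answer": "Yes", "count": 4},
]
labels, source, target, value = sankey_links(rows, ["gender", "answer"], "count")
print(labels)  # ['F', 'M', 'Yes', 'No']
print(list(zip(source, target, value)))  # [(0, 2, 7), (1, 2, 2), (0, 3, 1)]
```

The four returned lists correspond to `node.label`, `link.source`, `link.target` and `link.value` in the dict genSankey builds. Unlike the `groupby` in genSankey, the links here come out in first-seen order rather than sorted by label, which changes only the order of the link arrays. The figure dict returned by genSankey can then be rendered with Plotly, for example `plotly.graph_objects.Figure(fig).show()` (assuming Plotly is installed).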
@paulrougieux

Hi ken333135, is it possible to re-use your code and redistribute it under an MIT license?

@ken333135
Author

Yea sure. Feel free!

@paulrougieux

Thanks, I use it for a diagnostic plot in a modelling pipeline we're building.

@ken333135
Author

ken333135 commented Sep 26, 2019 via email

@jpsteege

jpsteege commented Nov 5, 2020

Nice piece of code and very usable Medium article!
For my own use, I added a snippet that generates as many colors as there are cat_cols, using Seaborn color palettes (here: 'Spectral'). For more palette options, check https://seaborn.pydata.org/tutorial/color_palettes.html.

import pandas as pd
import seaborn as sns

def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # Source: https://medium.com/kenlok/how-to-create-sankey-diagrams-from-dataframes-in-python-e221c1b4d6b0
    # Create colors based on the number of categorical columns
    colorPalette = sns.color_palette("Spectral", len(cat_cols)).as_hex()
    labelList = []
    colorNumList = []

...
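If Seaborn is not available, a similar effect (as many colors as there are cat_cols) can be sketched with the standard library by spacing hues evenly around the HSV color wheel. This is a different technique from Seaborn's 'Spectral' palette, so the colors will differ; `hue_palette` and its saturation/value defaults are hypothetical choices, not part of the original code.

```python
import colorsys


def hue_palette(n, saturation=0.65, value=0.85):
    """Return n hex color strings with evenly spaced hues (stdlib-only sketch)."""
    colors = []
    for i in range(n):
        # hue in [0, 1): step around the color wheel in equal increments
        r, g, b = colorsys.hsv_to_rgb(i / n, saturation, value)
        colors.append("#{:02x}{:02x}{:02x}".format(round(r * 255), round(g * 255), round(b * 255)))
    return colors


print(hue_palette(3))
```

It could replace the fixed palette in genSankey as `colorPalette = hue_palette(len(cat_cols))`, which also removes the five-column limit of the hard-coded list.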

@rrosasl

rrosasl commented Dec 2, 2020

Hi Ken,

THANK YOU SO MUCH FOR THIS ONE
I have used this Sankey so many times :D
Most recently here https://rrosasl.medium.com/ranked-choice-voting-with-google-forms-and-python-c471ea568a60

There I also have some useful code for converting many DataFrames into the necessary format for the Sankey Diagram :)

@KeshavSharmaWeb

Thanks for the code!
