import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FuncFormatter
import seaborn as sns
# Seaborn "ticks" theme, with every font scaled up by a factor of 1.5.
sns.set(style="ticks", font_scale=1.5)
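# Fancy subplot grid
# ==================
#
# GridSpec is your friend.
#
# Introduction
# ------------
# With GridSpec you can create any combination of panels: a panel can span
# several rows or columns, and rows and columns can have different sizes.
#
# The code
# --------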
# figsize accepts only inches.
fig = plt.figure(1, figsize=(8, 6))
# 3 rows x 2 columns: the right column is half as wide as the left one, and
# the rows get shorter from top to bottom (ratios 1 : 0.7 : 0.3).
gs = gridspec.GridSpec(3, 2, width_ratios=[1, 0.5],
                       height_ratios=[1, 0.7, 0.3])
# Outer margins (as fractions of the figure) and the small gaps between panels.
gs.update(left=0.16, right=0.86, top=0.88, bottom=0.13,
          hspace=0.05, wspace=0.05)
###########
# subplot a
###########
# gs[0, :] makes this panel span both columns of the top row.
ax0 = fig.add_subplot(gs[0, :])


def heaviside(x):
    """Return the step function 0.5 * (sign(x) + 1): 0 below 0, 0.5 at 0, 1 above."""
    return 0.5 * (np.sign(x) + 1)


x = np.arange(0, 10.01, 0.01)
ax0.plot(x, heaviside(x - 2), color='purple', lw=3)
ax0.text(2.5, 1.1, r"$\longleftarrow$ heaviside")
# y ticks as a percentage
ax0.set_yticks(np.arange(-0.5, 2.0, 0.5))
def to_percent(y, position):
    """Format tick value ``y`` as a signed, whole-number percentage string.

    Intended as a ``FuncFormatter`` callback. ``position`` is accepted but
    ignored, so the default tick locations are simply relabelled (scaled by
    100): e.g. 0.5 becomes "+50%".
    """
    s = "{:+.0f}".format(y * 100)
    # The percent symbol needs escaping when text is rendered with LaTeX.
    if matplotlib.rcParams['text.usetex']:
        return s + r'$\%$'
    return s + '%'
# Create the formatter using the function to_percent. This multiplies all the
# default labels by 100, making them all percentages.
formatter = FuncFormatter(to_percent)
ax0.yaxis.set_major_formatter(formatter)
ax0.set_ylabel("heaviside, percentage")

# x ticks (and the x label) on top
ax0.axis([x.min(), x.max(), -0.5, 1.5])
ax0.xaxis.tick_top()
ax0.set_xlabel(r"x labels on top")
ax0.xaxis.set_label_position("top")

# transAxes makes the position relative to the axes (0..1), not to the data
ax0.text(0.97, 0.97, r"a", transform=ax0.transAxes,
         horizontalalignment='right', verticalalignment='top',
         fontweight="bold")

# Twin axes sharing the same x axis (y will be different).
ax0b = ax0.twinx()
ax0b.plot(x, np.tanh(x - 5), color="green", linewidth=3)
ax0b.axis([x.min(), x.max(), -1.1, 2.5])
ax0b.text(5.5, 0, r"tanh $\longrightarrow$")
ax0b.set_ylabel(r'tanh, offset label')
# Place the twin's y label explicitly (axes coordinates) instead of the default.
ax0b.yaxis.set_label_coords(1.1, 0.70)
###########
# subplot b
###########
ax10 = fig.add_subplot(gs[1, 0])
x = np.arange(-5, 5, 0.01)
y = np.exp(-x)
ax10.plot(x, y, color="orange", lw=3)
# Logarithmic y axis in base 2, with a tick at every third power of two.
ax10.set_yscale('log', base=2)
ax10.set_yticks(2.0 ** np.arange(-7, 7, 3))
ax10.text(1.0, 1, r"$y=e^{-x}$")
xticks = np.arange(-5, 6, 2)
ax10.set_xticks(xticks)
ax10.set_xticklabels(xticks, y=0.15)
ax10.yaxis.set_tick_params(direction='out')
ax10.set_ylabel("log scale base 2", labelpad=15)
ax10.text(0.97, 0.97, r"b", transform=ax10.transAxes,
          horizontalalignment='right', verticalalignment='top',
          fontweight="bold")
###########
# subplot c
###########
ax11 = fig.add_subplot(gs[1, 1])
x = np.arange(1.0, np.e ** 4, 0.01)
y = x ** (-0.8)
ax11.plot(x, y, color="cyan", lw=3)
ax11.text(2, 1, r"$y=x^{-0.8}$")
# Log-log axes in base e. Setting both scales is sufficient; an extra
# ax11.loglog(x, y) call would draw a second, unstyled copy of the curve on
# top of the cyan one (its base-10 scaling is overridden here anyway).
ax11.set_xscale("log", base=np.e)
ax11.set_yscale("log", base=np.e)
xt = np.exp(np.arange(1, 4, 1))       # e^1, e^2, e^3
yt = np.pi ** (np.arange(-3, 2, 1))   # pi^-3 ... pi^1
ax11.set_xticks(xt)
ax11.set_xticklabels(xt, y=0.15)
ax11.set_yticks(yt)
def ticks_e(y, pos):  # base e
    """Tick formatter: label the tick at ``y`` as ``e`` to a rounded integer power.

    ``pos`` (the tick index a FuncFormatter passes) is ignored. The exponent is
    wrapped in braces so that multi-character exponents such as ``10`` or
    ``-1`` are rendered entirely as a superscript by mathtext.
    """
    return r'$e^{{{:.0f}}}$'.format(np.log(y))
def ticks_pi(y, pos):  # base pi, why not?
    """Tick formatter: label the tick at ``y`` as ``pi`` to a signed integer power.

    ``pos`` (the tick index a FuncFormatter passes) is ignored. The exponent is
    log_pi(y), rounded to zero decimals and always shown with its sign.
    """
    exponent = np.log(y) / np.log(np.pi)
    return r'$\pi^{' + '{:+.0f}'.format(exponent) + '}$'
# Label the ticks as powers of e (x axis) and of pi (y axis).
ax11.xaxis.set_major_formatter(FuncFormatter(ticks_e))
ax11.yaxis.set_major_formatter(FuncFormatter(ticks_pi))
# y ticks and y label on the right-hand side
ax11.yaxis.tick_right()
ax11.yaxis.set_label_position("right")
ax11.set_ylabel("right side", labelpad=10)
ax11.text(0.97, 0.97, r"c", transform=ax11.transAxes,
          horizontalalignment='right', verticalalignment='top',
          fontweight="bold")
###########
# subplot d
###########
ax20 = fig.add_subplot(gs[2, 0])
ax20.axis([0, 1, 0, 1])
month_labels = ["January", "February", "March", "April",
                "May", "June"]
# One evenly spaced tick per label across [0, 1]; linspace ties the tick count
# to the number of labels instead of relying on a float-step arange endpoint.
ax20.set_xticks(np.linspace(0, 1, len(month_labels)))
ax20.set_xticklabels(month_labels,
                     rotation=30, horizontalalignment="right")
ax20.set_yticks([])
ax20.text(0.97, 0.97, r"d", transform=ax20.transAxes,
          horizontalalignment='right', verticalalignment='top',
          fontweight="bold")
###########
# subplot e
###########
ax21 = fig.add_subplot(gs[2, 1])
ax21.set_xticks([])
ax21.set_yticks([])
ax21.axis([0, 1, 0, 1])
ax21.text(0.97, 0.97, r"e", transform=ax21.transAxes,
          horizontalalignment='right', verticalalignment='top',
          fontweight="bold")

fig.savefig("subplot-grid.png", dpi=300)
# In a notebook, a trailing bare `fig` displays the figure inline;
# in a plain script this expression does nothing.
fig