import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
def figsize(scale, nplots=1, fig_width_pt=390.0):
    r"""Define the figure size as ``[width, height]`` in inches.

    The width is the fraction ``scale`` of the LaTeX text width. The height
    is the width times the golden ratio, multiplied by ``nplots`` (e.g. to
    make room for several plots in one figure).

    Parameters
    ----------
    scale : float
        Fraction of the text width to use (1.0 = full text width).
    nplots : int, optional
        Height multiplier; defaults to 1.
    fig_width_pt : float, optional
        LaTeX text width in points. Get it from LaTeX with
        ``\the\textwidth``. Defaults to 390.0.

    Returns
    -------
    list of float
        ``[fig_width, fig_height]`` in inches.
    """
    inches_per_pt = 1.0 / 72.27  # TeX uses 72.27 points per inch
    # Aesthetic ratio (you could change this)
    golden_mean = (np.sqrt(5.0) - 1.0) / 2.0
    fig_width = fig_width_pt * inches_per_pt * scale  # width in inches
    fig_height = nplots * fig_width * golden_mean  # height in inches
    return [fig_width, fig_height]
# rcParams that set up matplotlib to typeset its output with LaTeX.
pgf_with_latex = {
    "pgf.texsystem": "pdflatex",  # change this if using xelatex or lualatex
    "text.usetex": True,  # use LaTeX to write all text
    "font.family": "DejaVu Sans",
    # blank entries should cause plots to inherit fonts from the document
    "font.serif": [],
    "font.sans-serif": [],
    "font.monospace": [],
    "axes.labelsize": 10,  # LaTeX default is 10pt font.
    "font.size": 10,
    "legend.fontsize": 8,  # Make the legend/label fonts a little smaller
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "figure.figsize": figsize(1.0),  # default fig size: full textwidth
    # Matplotlib >= 3.3 expects the preamble as ONE string; a list of lines
    # is deprecated there and rejected by newer releases, which would make
    # the rcParams update fail and discard all of these settings.
    "pgf.preamble": "\n".join([
        # input encoding so non-ASCII characters in labels are accepted
        # NOTE(review): utf8x comes from the unmaintained ucs package;
        # plain utf8 may be preferable - confirm with the target documents.
        r"\usepackage[utf8x]{inputenc}",
        # T1 font encoding for the generated plots
        r"\usepackage[T1]{fontenc}",
    ]),
}
# Try to apply the LaTeX settings. If matplotlib rejects an entry, warn and
# restore the library defaults instead of failing at import time.
# Note: this validates the rc values only; it does not check that a LaTeX
# installation is actually available.
try:
    mpl.rcParams.update(pgf_with_latex)
except (KeyError, ValueError) as err:
    # rcParams raises KeyError for unknown keys and ValueError for invalid
    # values. Keys processed before the failing one may already have been
    # applied, so reset everything. rcdefaults() is the supported way to do
    # this and leaves style-blacklisted keys such as the backend untouched.
    warnings.warn(
        "Could not apply LaTeX rcParams ({}); "
        "falling back to matplotlib defaults.".format(err),
        stacklevel=2,
    )
    mpl.rcdefaults()
def newfig(width, nplots=1, dpi=600):
    """Create a new figure with a single set of axes.

    Parameters
    ----------
    width : float
        Fraction of the text width, passed to ``figsize`` as ``scale``.
    nplots : int, optional
        Height multiplier, passed to ``figsize``. Defaults to 1.
    dpi : int, optional
        Figure resolution in dots per inch. Defaults to 600.

    Returns
    -------
    tuple
        ``(fig, ax)``: the new figure and its single ``111`` subplot.
    """
    fig = plt.figure(figsize=figsize(width, nplots), dpi=dpi)
    ax = fig.add_subplot(111)
    return fig, ax
def savefig(filename, crop=True):
    """Save the current pyplot figure as a pdf file named ``<filename>.pdf``.

    ``.pdf`` is always appended to ``filename``.

    Parameters
    ----------
    filename : str or path-like
        Output path without the ``.pdf`` extension.
    crop : bool, optional
        Both modes trim surrounding whitespace (``bbox_inches='tight'``).
        A truthy value crops flush (0 padding); a falsy value keeps a
        0.03-inch margin. Defaults to True.
    """
    # Truthiness rather than ``crop is True``, so e.g. ``crop=1`` also crops
    # flush instead of silently getting the padded margin.
    pad = 0 if crop else 0.03
    plt.savefig('{}.pdf'.format(filename), bbox_inches='tight',
                pad_inches=pad)
class MathTextSciFormatter(mticker.Formatter):
    r"""Format axis tick labels in mathtext scientific notation.

    With the default ``"%1.2e"``, ``12345`` is rendered as
    ``$1.23{\times}10^{4}$`` and ``1.0`` as ``$1.00$`` (a zero exponent is
    dropped).
    """

    def __init__(self, fmt="%1.2e"):
        """Store the ``%``-style format applied to each tick value.

        Parameters
        ----------
        fmt : str, optional
            Format expected to produce exponent notation (``%e``/``%E``).
            Output without an exponent is passed through unchanged.
        """
        self.fmt = fmt

    def __call__(self, x, pos=None):
        """Return the mathtext label for tick value ``x`` (``pos`` unused)."""
        s = self.fmt % x
        decimal_point = '.'
        positive_sign = '+'
        significand, sep, exp_part = s.partition('e')
        if not sep:
            # %E-style formats use an upper-case exponent marker.
            significand, sep, exp_part = s.partition('E')
        if not exp_part:
            # Nothing to rewrite (e.g. 'inf'/'nan', or a non-%e fmt); the
            # previous split('e')-based code raised IndexError here.
            return "${}$".format(s)
        significand = significand.rstrip(decimal_point)
        # %e always emits an explicit exponent sign; keep only '-'.
        sign = exp_part[0].replace(positive_sign, '')
        exponent = exp_part[1:].lstrip('0')
        if exponent:
            exponent = '10^{%s%s}' % (sign, exponent)
        if significand and exponent:
            s = r'%s{\times}%s' % (significand, exponent)
        else:
            s = r'%s%s' % (significand, exponent)
        return "${}$".format(s)