matplotlib
This is a quick tour of basic plotting with matplotlib. For more detail see:
You can install matplotlib on your own machine (e.g. `python3 -m pip install matplotlib`).
Matplotlib is most often used in a notebook setting.
Matplotlib and numpy are pre-installed in the Google Colab notebook environment.
When working with functions, you are probably used to thinking of `x` as either a variable or a number.
When working with numpy, most things are vectorized. Taking advantage of that idea, instead of `x` being just a number, in this notebook we will use that name for a vector containing all the numbers we plan to consider.
We could call it something else, like `xvalues` or `xvec`, but in this notebook we chose the shortest reasonable name, simply `x`.
# This submodule is all you need in most cases, and plt is a common
# abbreviated name to use for it.
import matplotlib.pyplot as plt
# For a very few things you need access to the parent module
# This also lets us check the installed version.
# Uncomment the next two lines to import that.
# import matplotlib as mpl
# mpl.__version__
import numpy as np
#plt.style.available
#plt.style.use("seaborn-whitegrid")
# Plot y = sin(x) over two full periods.
# Sample the interval [0, 4*pi] at 100 evenly spaced points.
x = np.linspace(0, 4 * np.pi, 100)
# Slow, element-by-element way (please avoid this):
#     y = np.array([np.sin(t) for t in x])
# np.sin is vectorized, so apply it to the entire array in one call instead.
y = np.sin(x)

plt.figure(figsize=(8, 6))  # start a new figure (it may hold several plots)
plt.plot(x, y)  # arguments: array of x values, array of matching y values
plt.title("Sine function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.show()  # display the finished figure
# Plot y = sin(x) again, this time marking each sample point.
# Sample the interval [0, 4*pi] at 100 evenly spaced points.
x = np.linspace(0, 4 * np.pi, 100)
# Slow, element-by-element way (please avoid this):
#     y = np.array([np.sin(t) for t in x])
# np.sin is vectorized, so apply it to the entire array in one call instead.
y = np.sin(x)

plt.figure(figsize=(8, 6))  # start a new figure (it may hold several plots)
plt.plot(x, y, marker="x")  # marker="x" draws an x at every data point
plt.title("Sine function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.show()  # display the finished figure
# Plot the Gaussian bump f(x) = e^{-x^2} on [-3, 3].
x = np.linspace(start=-3, stop=3, num=100)
y = np.exp(-x * x)
plt.plot(x, y)  # no explicit plt.figure(): pyplot makes a figure when needed
plt.show()
# Three Gaussian bumps plus a fast oscillation, all drawn on one set of axes.
x = np.linspace(start=-3, stop=3, num=500)
y = np.exp(-x * x)  # f(x) = e^{-x^2}
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)  # g(x) = 1.5 e^{-3(x-2)^2}
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)  # h(x) = 0.8 e^{-4(x+1)^2}

plt.figure(figsize=(8, 6))
# Successive plt.plot calls add their curves to the same axes.
# Each label= string becomes that curve's entry in the legend.
plt.plot(x, y, color="red", label="Stewart")
plt.plot(x, y2, color="pink", label="Tina")
plt.plot(x, y3, color="#C4E434", label="29.5")  # color as hex RRGGBB
plt.plot(x, 0.5 + 0.7 * np.sin(25 * x), label="ruiner")  # default color
plt.legend()
plt.show()
# The same three bumps, this time varying the line and marker styling.
x = np.linspace(start=-3, stop=3, num=50)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

plt.plot(x, y, color="#FF0000")  # solid line in pure red
plt.plot(x, y2, linewidth=3, linestyle="dashed")  # thick dashed line
plt.plot(x, y3, linestyle="", marker="o")  # markers only, no connecting line
plt.show()
# Curves sampled on different x grids can share axes; xlim/ylim then choose
# which window of the plane is actually shown.
x = np.linspace(start=-3, stop=3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

# A second, differently sized grid used only for the fourth curve.
x2 = np.linspace(start=-2, stop=4, num=150)
y4 = 0.5 * np.exp(-(x2 - 2) ** 2) * np.sin(6 * x2)

plt.plot(x, y)
plt.plot(x, y2)
plt.plot(x, y3)
plt.plot(x2, y4)
plt.xlim(-1, 3)  # show only -1 <= x <= 3
plt.ylim(0, 1.6)  # show only 0 <= y <= 1.6
plt.show()
# Parametric curves: plot (x(t), y(t)) as t runs over [0, 2*pi].
t = np.linspace(0, 2 * np.pi, 300)
x = np.cos(t)
y = np.sin(t)
plt.plot(x, y)  # the unit circle
plt.plot(np.sin(2 * t), np.cos(t))  # a second curve on the same axes
plt.axis("equal")  # same scale on both axes, so the circle looks round
# Plot y = tan(x) on [-3, 3].
x = np.linspace(start=-3, stop=3, num=100)
y = np.tan(x)
# tan(x) has vertical asymptotes at x = -pi/2 and x = pi/2.  Plotted naively,
# the huge values next to them stretch the y-axis and plt.plot joins the two
# sides of each asymptote with a near-vertical line.  Replacing the large
# values with NaN makes plt.plot leave a gap there instead.
y_max = 10
y[np.abs(y) > y_max] = np.nan
plt.plot(x, y)
plt.ylim(-y_max, y_max)  # keep the vertical scale readable
plt.show()
# Three Gaussian bumps with LaTeX-formatted legend entries, shown on screen
# and also written to disk as PNG and PDF.
x = np.linspace(start=-3, stop=3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

plt.figure(figsize=(8, 6))
# Text between $...$ in a label is typeset as math.
plt.plot(x, y, label="$e^{-x^2}$")
plt.plot(x, y2, label="$1.5e^{-3(x-2)^2}$")
plt.plot(x, y3, label="$0.8e^{-4(x+1)^2}$")
plt.legend()
# Save BEFORE calling plt.show(): once show() has run, the figure is typically
# closed, and a later savefig would write out a blank canvas.
plt.savefig("three_gaussians.png", dpi=300)  # raster image at 300 dots/inch
plt.savefig("three_gaussians.pdf")  # output format follows the extension
plt.show()
# The three bumps again, with colors, line styles and widths set per curve.
x = np.linspace(start=-3, stop=3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

plt.figure(figsize=(8, 6))
plt.plot(
    x, y,
    label="$e^{-x^2}$",
    color="orange",
    linestyle="dashed",
    linewidth=5,
)
plt.plot(x, y2, label="$1.5e^{-3(x-2)^2}$", color="#FF0080")
plt.plot(x, y3, label="$0.8e^{-4(x+1)^2}$", linestyle="dotted")
plt.legend()
plt.show()
# Same styled plot, with a star marker added at each point of the second curve.
x = np.linspace(start=-3, stop=3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

plt.figure(figsize=(8, 6))
plt.plot(
    x, y,
    label="$e^{-x^2}$",
    color="orange",
    linestyle="dashed",
    linewidth=5,
)
plt.plot(x, y2, label="$1.5e^{-3(x-2)^2}$", color="#FF0080", marker="*")
plt.plot(x, y3, label="$0.8e^{-4(x+1)^2}$", linestyle="dotted")
plt.legend()
plt.show()
# plt.plot draws one and the same marker (identical size, shape and color)
# at every data point; it cannot vary the marker from point to point.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
plt.plot(n, t, marker="o", linestyle="", color="orange")  # dots, no line
plt.show()
# plt.scatter lets the size and color differ from one point to the next.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])  # relative marker sizes
c = np.array([1, 2, 3, 5, 8, 20])  # numbers turned into colors by the cmap
plt.scatter(n, t, s=250 * s, c=c, marker="o", cmap="Pastel2")
plt.colorbar()  # legend relating colors to the values in c
# The same scatter plot drawn with a different colormap.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])  # relative marker sizes
c = np.array([1, 2, 3, 5, 8, 20])  # numbers turned into colors by the cmap
plt.scatter(n, t, s=250 * s, c=c, marker="o", cmap="seismic")
plt.colorbar()  # legend relating colors to the values in c
# Colors can also be given directly by name, one per point; no cmap is used.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])  # relative marker sizes
c = np.array(["red", "red", "red", "blue", "blue", "blue"])
plt.scatter(n, t, s=250 * s, c=c, marker="o")
# List the names of the colormaps that can be passed as cmap=.
plt.colormaps()
CSV with data about meteorites recovered on Earth's surface, adapted from a NASA dataset:
import numpy as np
import matplotlib.pyplot as plt
import csv
# This cell just grabs the contents of "meteorites.csv"
# and returns it as a dict mapping column names to vectors of column data
import collections
# Read "meteorites.csv" into `columns`, a mapping from each column name to a
# numpy vector holding that column's data.
columns = collections.defaultdict(list)
# The encoding is given explicitly so the result does not depend on the
# platform's default text encoding; newline="" is what the csv module expects.
with open("meteorites.csv", "r", newline="", encoding="utf-8") as fp:
    for row in csv.DictReader(fp):
        for k, v in row.items():
            columns[k].append(v)
# Convert every column from a list of strings to a numeric array: "year" is
# stored as integers, all other columns as 64-bit floats.
for k in columns:
    dtype = "int" if k == "year" else "float64"
    columns[k] = np.array(columns[k]).astype(dtype)
# Scatter plot of every meteorite, positioned by longitude and latitude.
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,  # s is the dot's area; grows with mass
    alpha=0.6,
    c=columns["year"],  # the dot's color encodes the year
    cmap="PuOr",
)
plt.colorbar()
plt.show()
Pretty nice. Looks like a world map. Questions:
# The same scatter plot, restricted to a longitude/latitude window.
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,  # s is the dot's area; grows with mass
    alpha=0.6,
    c=columns["year"],  # the dot's color encodes the year
    cmap="PuOr",
)
plt.axis("equal")  # one degree spans the same distance on both axes
plt.xlim(-140, -60)  # longitude window
plt.ylim(10, 60)  # latitude window
plt.colorbar()
plt.show()
# The world scatter plot again, with a text label attached to one meteorite.
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,
    alpha=0.6,  # dots are 60% opaque, 40% transparent
    c=columns["year"],
    cmap="PuOr",
)
plt.colorbar()
plt.annotate(
    "Cape York Meteorite (1818)",
    xy=(-64.933, 76.13),  # the point being annotated
    xycoords="data",  # tells matplotlib that xy is given in data units
    xytext=(0, 15),  # where the text goes
    # Tells matplotlib how to read xytext: the units are points and the
    # origin is the point being annotated.
    textcoords="offset points",
    horizontalalignment="center",
    verticalalignment="bottom",
)
plt.show()
https://en.wikipedia.org/wiki/Cape_York_meteorite#/media/File:Ahnighito_AMNH,_34_tons_meteorite.jpg
# Sample a function of two variables on a rectangular grid.
x = np.linspace(-3, 3, 100)
y = np.linspace(-2, 2, 80)
xx, yy = np.meshgrid(x, y)  # 2D arrays of the x and y coordinates

# f(x,y) = x**3 - 8x + 3*y**2 + 0.5*y**3
# zz is an 80x100 matrix holding the value of f at every grid point.
zz = xx ** 3 - 8 * xx + 3 * yy ** 2 + 0.5 * yy ** 3

# Where is f(x,y) = 0.2?  Draw just that one level curve.
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz, [0.2])
plt.show()
# Contour plot: level curves of f, with the levels chosen automatically.
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz)
plt.colorbar()  # shows which color goes with which level
# Filled contour plot: the regions between level curves are colored in.
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, zz)
plt.colorbar()  # shows which color goes with which range of values
`plt.clabel` adds labels to an existing contour plot. Its argument is the return value of a previous call to `plt.contour`.
# Contour plot with 15 levels, each curve labelled inline with its value.
plt.figure(figsize=(8, 6))
contours = plt.contour(xx, yy, zz, 15, cmap="magma")  # keep the return value
plt.title("Contour plot")
plt.clabel(contours)  # write each level's value onto its curve
plt.colorbar()
plt.imshow
# Density plot: show the matrix zz as an image, one colored cell per entry.
plt.figure(figsize=(8, 6))
plt.imshow(
    zz,
    # extent gives the x and y ranges that the image edges correspond to.
    extent=[np.min(x), np.max(x), np.min(y), np.max(y)],
    origin="lower",
)
# origin="lower" means the first row of zz appears at the bottom of the plot.
# That's correct since our meshgrid has smallest y values in the first row.
plt.title("Density plot")
plt.colorbar()
# Overlay white labelled contours on top of the density plot.
plt.figure(figsize=(8, 6))
contours = plt.contour(xx, yy, zz, 15, colors="white")
plt.title("Contour and density plot")
plt.clabel(contours)  # write each level's value onto its curve
plt.imshow(
    zz,
    extent=[np.min(x), np.max(x), np.min(y), np.max(y)],
    origin="lower",  # first row of zz at the bottom, matching the meshgrid
)
plt.colorbar()
# This is adapted from an example in the matplotlib docs:
# https://matplotlib.org/stable/gallery/color/named_colors.html
# This uses lots of matplotlib features we don't cover in MCS 275!
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
def plot_colortable(colors, title, sort_colors=True, emptycols=0):
    """Draw a table of labelled color swatches and return the figure.

    The table is laid out on a canvas four columns wide.  Entries fill each
    column from top to bottom before moving on to the next column.

    Parameters
    ----------
    colors : dict
        Maps a display name to a matplotlib color spec.  A string spec
        (e.g. "#D50032") is printed under the name exactly as given; any
        other spec (e.g. an RGB tuple) is printed as its hex code.
    title : str
        Heading drawn at the top left of the table.
    sort_colors : bool, optional
        When exactly True, order the swatches by hue, saturation, value and
        then name; otherwise keep the iteration order of `colors`.
    emptycols : int, optional
        How many of the four columns to leave unused.

    Returns
    -------
    The matplotlib figure containing the table.

    Raises
    ------
    ValueError
        If `emptycols` leaves no column to draw in.
    """
    total_cols = 4  # the canvas is always this many columns wide
    cell_width = 212
    cell_height = 48
    swatch_width = 64
    margin = 12
    topmargin = 56

    ncols = total_cols - emptycols
    if ncols < 1:
        # Without this check the row computation below divides by zero
        # (or produces a meaningless layout for a negative column count).
        raise ValueError(
            "emptycols must be less than {}, got {!r}".format(
                total_cols, emptycols
            )
        )

    # Sort colors by hue, saturation, value and name.
    if sort_colors is True:
        by_hsv = sorted(
            (tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))), name)
            for name, color in colors.items()
        )
        names = [name for _, name in by_hsv]
    else:
        names = list(colors)

    n = len(names)
    nrows = (n + ncols - 1) // ncols  # ceiling of n / ncols

    width = cell_width * total_cols + 2 * margin
    height = cell_height * nrows + margin + topmargin
    dpi = 72

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(margin / width, margin / height,
                        (width - margin) / width,
                        (height - topmargin) / height)
    ax.set_xlim(0, cell_width * total_cols)
    # The y-limits are given in decreasing order so row 0 is at the top.
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=24, loc="left", pad=10)

    for i, name in enumerate(names):
        row = i % nrows
        col = i // nrows
        y = row * cell_height

        swatch_start_x = cell_width * col
        swatch_end_x = swatch_start_x + swatch_width
        text_pos_x = swatch_end_x + 7

        color = colors[name]
        # String specs are shown verbatim; anything else is converted to a
        # hex string so it can be joined to the name.
        color_text = color if isinstance(color, str) else mcolors.to_hex(color)

        ax.text(text_pos_x, y, name + "\n" + color_text, fontsize=14,
                horizontalalignment='left',
                verticalalignment='center')
        # The swatch is a short, very thick horizontal line.
        ax.hlines(y, swatch_start_x, swatch_end_x,
                  color=color, linewidth=28)

    return fig
# Swatch table for the two primary colors, in the order listed, with one of
# the four columns left empty.
plot_colortable(
    colors={
        "Fire Engine Red": "#D50032",
        "Navy Pier Blue": "#001E62",
    },
    title="Primary",
    sort_colors=False,
    emptycols=1,
)
# Swatch table for the remaining palette, in the order listed, with one of
# the four columns left empty.  The title is "Secondary" so that this table
# is distinguishable from the "Primary" table drawn just before it.
plot_colortable(
    {
        "Chicago Blue": "#41B6E6",
        "Champions Gold": "#FFBF3F",
        "UI Health Sky Blue": "#0065AD",
        "Expo White": "#F2F7EB",
        "Steel Gray": "#333333",
    },
    "Secondary",
    sort_colors=False,
    emptycols=1,
)
plt.show()