# Copyright (c) 2019, Matthias Frey, Paul Scherrer Institut, Villigen PSI, Switzerland
# All rights reserved
#
# Implemented as part of the PhD thesis
# "Precise Simulations of Multibunches in High Intensity Cyclotrons"
#
# This file is part of pyOPALTools.
#
# pyOPALTools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# You should have received a copy of the GNU General Public License
# along with pyOPALTools. If not, see <https://www.gnu.org/licenses/>.
from .BasePlotter import *
import numpy as np
from .formatter import FormatScalarFormatter
import os
class SamplerPlotter(BasePlotter):
    """Plots for sampler runs (one simulation per sample directory).

    The methods read ``self.ds``, which is assigned outside this class
    (presumably the sampler dataset -- it must provide ``size``,
    ``filename``, ``getLabelWithUnit`` and ``find_matches``). On success
    every ``plot_*`` method returns the ``matplotlib.pyplot`` module; on
    any exception the error is logged with ``opal_logger`` and an empty
    figure is returned instead.
    """

    def __init__(self):
        # Nothing to initialise here; ``self.ds`` is presumably assigned by
        # the code that owns this plotter.
        pass

    def plot_variability(self, fname, xvar, yvar, **kwargs):
        """Plot the mean, min and max over all samples.

        Sample ``i`` is loaded from the sub-directory ``<dir>/<i>`` where
        ``<dir>`` is the directory of ``self.ds.filename`` and ``i`` runs
        over ``0 .. self.ds.size - 1``.

        Parameters
        ----------
        fname : str
            File containing the data (xvar and yvar)
        xvar : str
            x-axis data (taken from sample 0)
        yvar : str
            y-axis data
        idx : bool, optional
            Fix the x-axis labels (takes the original
            data order but uses the indices to plot
            and the values as ticks), useful for
            periodic values (e.g. azimuth). Default: False.
        nticks : int, optional
            Number of ticks on axes (only for idx=True). Default: 10.

        All remaining keyword arguments are forwarded to ``getData`` of
        every loaded dataset and to ``plt.plot`` of the mean curve (where
        they may override the default ``color``, ``linestyle`` and
        ``label``).

        Returns
        -------
        matplotlib.pyplot
            Plot handle
        """
        try:
            from opal import load_dataset

            # Pop the plot-only options before kwargs is forwarded anywhere,
            # so getData() never receives them.
            use_idx = kwargs.pop('idx', False)
            nticks = kwargs.pop('nticks', 10)

            nsamples = self.ds.size
            if nsamples < 1:
                opal_logger.error('No samples to plot (nsamples = ' +
                                  str(nsamples) + ')')
                return plt.figure()

            dirname = os.path.dirname(self.ds.filename)

            def _load(i):
                # Load the dataset stored in simulation directory <dirname>/<i>.
                sdir = os.path.join(dirname, str(i))
                return load_dataset(sdir, fname=fname, info=False)

            # Sample 0 provides the x-axis and seeds the sum / min / max, so it
            # is loaded only once.
            first = _load(0)
            xdata = np.asarray(first.getData(xvar, **kwargs))
            data = np.asarray(first.getData(yvar, **kwargs), dtype=float)
            ysum = data.copy()
            ymin = data.copy()
            ymax = data.copy()

            for i in range(1, nsamples):
                data = _load(i).getData(yvar, **kwargs)
                ysum += data
                ymin = np.minimum(ymin, data)
                ymax = np.maximum(ymax, data)

            mean = ysum / float(nsamples)

            # Defaults for the mean curve; user kwargs take precedence instead
            # of raising "got multiple values for keyword argument".
            mean_style = dict(color='black', linestyle='dashed', label='mean')
            mean_style.update(kwargs)

            if use_idx:
                npoints = len(xdata)
                xplot = np.arange(npoints)
            else:
                xplot = xdata

            plt.plot(xplot, mean, **mean_style)
            plt.fill_between(xplot, ymin, ymax,
                             facecolor='blue', alpha=0.2,
                             label='variability region')

            if use_idx:
                # Label every step-th index with its rounded original x value.
                # int(npoints / nticks) - 1 is <= 0 when npoints < 2 * nticks;
                # a zero slice step raises ValueError and a negative one
                # reverses the ticks, hence the clamp to at least 1.
                step = max(1, int(npoints / max(1, nticks)) - 1)
                plt.xticks(xplot[::step],
                           np.round(xdata, 0)[::step].astype(int))

            plt.legend(loc='upper center',
                       ncol=2, labelspacing=0.5,
                       bbox_to_anchor=(0.5, 1.1, 0.0, 0.0))
            plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-2, 2),
                                       useMathText=True, useOffset=True)
            plt.xlabel(self.ds.getLabelWithUnit(xvar))
            plt.ylabel(self.ds.getLabelWithUnit(yvar))
            plt.tight_layout()
            return plt
        except Exception as ex:
            opal_logger.exception(ex)
            return plt.figure()

    def plot_auto_correlation(self, ind, **kwargs):
        """Compare a sample set with itself.

        For every split ``ntrain = 1 .. len(ind) - 1`` the first ``ntrain``
        indices are compared with the remaining ones via
        ``self.ds.find_matches``; the result is plotted over ``ntrain``.

        Parameters
        ----------
        ind : list
            Indices of the sample set.
        nsamples : bool, optional
            Show a horizontal line indicating
            the total number of samples (ignored when ``percent`` is
            set). Default: False.
        percent : bool, optional
            Plot the matches in percent of ``len(ind)`` instead of
            absolute counts. Default: False.

        Returns
        -------
        matplotlib.pyplot
            Plot handle
        """
        try:
            percent = kwargs.pop('percent', False)
            show_total = kwargs.pop('nsamples', False)
            nsamples = len(ind)

            scale = float(nsamples) * 0.01 if percent else 1.0

            # A bounded range (instead of 'while ntrain != nsamples') cannot
            # loop forever when ind is empty.
            splits = range(1, nsamples)
            matches = [self.ds.find_matches(ind[0:ntrain], ind[ntrain:]) / scale
                       for ntrain in splits]

            # x-values are the number of samples in the first subset.
            plt.plot(np.arange(1, nsamples), matches)

            ylabel = '#identical samples'
            if percent:
                ylabel = ylabel + ' in %'
            plt.xlabel('#samples')
            plt.ylabel(ylabel)

            if show_total and not percent:
                plt.axhline(nsamples, linestyle='dashed', label='#samples')
                # Only add a legend when a labelled artist exists.
                plt.legend(loc='upper center', ncol=1, labelspacing=0.,
                           bbox_to_anchor=(0.5, 1.1, 0.0, 0.0))
            plt.tight_layout()
            return plt
        except Exception as ex:
            opal_logger.exception(ex)
            return plt.figure()

    def plot_training_vs_validation(self, train0, **kwargs):
        """Bar plot comparing training with validation set.

        For each training set the validation set consists of all sample
        indices ``0 .. self.ds.size - 1`` that are not in the training set;
        the bar height is ``self.ds.find_matches(train, validation)``.

        Parameters
        ----------
        train0 : list
            Indices of the training points.
        train1 : list, optional
            More lists with indices
            train2, train3, etc. are also keywords (collected until the
            first missing number)
        nsamples : bool, optional
            Show a horizontal line indicating
            the total number of samples. Default: False.
        percent : bool, optional
            Write the agreement in percent of the total number of samples
            above each bar. Default: True.

        Returns
        -------
        matplotlib.pyplot
            Plot handle
        """
        try:
            nsamples = self.ds.size
            trains = [train0]
            i = 1
            while True:
                train = kwargs.pop('train' + str(i), None)
                # 'is None': '== None' is ambiguous for NumPy arrays.
                if train is None:
                    break
                trains.append(train)
                i += 1

            matches = []
            ntrains = []
            for train in trains:
                ntrain = len(train)
                ntrains.append(ntrain)
                if ntrain >= nsamples:
                    opal_logger.error('ntrain (' + str(ntrain) + ') >= ' +
                                      'nsamples (' + str(nsamples) + ')')
                # All indices not in the training set; a set gives O(1)
                # membership tests instead of scanning train for every index.
                train_set = set(train)
                validation = [k for k in range(int(nsamples))
                              if k not in train_set]
                matches.append(self.ds.find_matches(train, validation))

            ind = np.arange(len(ntrains))
            bars = plt.bar(ind, matches)
            plt.xticks(ind, ntrains)
            plt.xlabel('#training samples')
            plt.ylabel('#identical samples with validation set')

            topline = kwargs.pop('nsamples', False)
            if topline:
                plt.axhline(nsamples, linestyle='dashed', label='#samples')

            if kwargs.pop('percent', True):
                # 12. April 2019
                # https://matplotlib.org/gallery/lines_bars_and_markers/barchart.html#sphx-glr-gallery-lines-bars-and-markers-barchart-py
                for rect in bars:
                    height = rect.get_height()
                    plt.text(rect.get_x() + rect.get_width()*0.5, 1.01*height,
                             '{}'.format(height * 100.0 / nsamples) + '%',
                             ha='center', va='bottom')

            if topline:
                plt.legend(loc='upper center', ncol=1, labelspacing=0.,
                           bbox_to_anchor=(0.5, 1.1, 0.0, 0.0))
            plt.tight_layout()
            return plt
        except Exception as ex:
            opal_logger.exception(ex)
            return plt.figure()

    # 12. April 2019
    # https://matplotlib.org/gallery/lines_bars_and_markers/barchart.html#sphx-glr-gallery-lines-bars-and-markers-barchart-py
    def _autolabel(self, ax, rects, xpos='center'):
        """Attach a text label above each bar in *rects*, displaying its height.

        Copied from matplotlib.org. It's slightly modified.

        Parameters
        ----------
        ax : matplotlib axes
            Axes the labels are drawn on (via ``ax.text``).
        rects : iterable
            Bar patches, e.g. the return value of ``plt.bar``.
        xpos : str, optional
            Side on which to place the text w.r.t. the center of the bar;
            one of {'center', 'right', 'left'} (case-insensitive).
        """
        xpos = xpos.lower()  # normalize the case of the parameter
        ha = {'center': 'center', 'right': 'left', 'left': 'right'}
        offset = {'center': 0.5, 'right': 0.57, 'left': 0.43}  # x_txt = x + w*off
        for rect in rects:
            height = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()*offset[xpos], 1.01*height,
                    '{}'.format(height), ha=ha[xpos], va='bottom')