#!/usr/bin/python
# Copyright (c) 2017, Andreas Adelmann, Paul Scherrer Institut, Villigen PSI, Switzerland
# All rights reserved
#
# This file is part of pyOPALTools.
#
# pyOPALTools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# You should have received a copy of the GNU General Public License
# along with pyOPALTools. If not, see <https://www.gnu.org/licenses/>.
import os
import sys
import glob
import json
import math
import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
from collections import OrderedDict
from optPilot.Annotate import AnnoteFinder
from opal.parser.OptimizerParser import OptimizerParser as jsonreader
# ---------------------------------------------------------------------------
# Data parsing: module-level state shared by the functions below and main()
# ---------------------------------------------------------------------------
# Column index of every design variable, objective and "ID" in a data matrix;
# filled by readJSONData().
nameToColumnMap: dict = {}
# Column names picked on the command line; entries [0] and [1] are x and y.
selected_ids: list = []
# Run configuration (main() parses command-line options of the same names).
path: str = ""
videoname: str = ""
outpath: str = "./"
filename_postfix: str = "results.json"
generation: int = -1
plotAll: bool = False
# str(generation) -> data matrix most recently read for that generation.
data: dict = {}
def readJSONData(filename):
    """Read one generation of optimizer results into a plain 2-D array.

    ``filename`` is expected to be named ``<generation>_<postfix>``: its
    directory is handed to the OPAL ``OptimizerParser`` and the generation
    number is the part of the basename before the first underscore.

    Side effect: updates the module-level ``nameToColumnMap`` with the column
    index of every design variable, every objective and ``"ID"``.

    :param filename: path of the results file of one generation.
    :returns: array whose columns are [design variables..., objectives..., ID].
    :raises ValueError: if the basename does not start with an integer.
    """
    # A bare file name has an empty dirname; '' + '/' would point the parser
    # at the filesystem root instead of the current directory.
    dirname = os.path.dirname(filename) or os.curdir
    optjson = jsonreader()
    optjson.parse(dirname + '/')

    # The generation number is encoded as the file-name prefix.
    basename = os.path.basename(filename)
    generation = int(basename.split("_", 1)[0])
    optjson.readGeneration(generation)

    # Column layout of the returned matrix: design variables, objectives, ID.
    names = list(optjson.getDesignVariables()) + list(optjson.getObjectives()) + ["ID"]
    for idx, name in enumerate(names):
        nameToColumnMap[name] = idx

    dvarval = optjson.getAllInput()
    objsval = optjson.getAllOutput()
    ids = optjson.getIDs()
    return np.column_stack((dvarval, objsval, ids))
def improveName(name):
    """Normalise a column name given on the command line.

    Strips surrounding whitespace and a leading '%', removes all spaces and
    newlines, and escapes underscores as ``\\_`` (the plot renders text with
    LaTeX via ``text.usetex``).

    :param name: raw name, e.g. one comma-separated item of ``--objectives``.
    :returns: the cleaned name.
    """
    # NOTE(review): the result is also used as a nameToColumnMap key; confirm
    # the parser's names carry the same '\_' escaping.
    name = name.strip()  # must be assigned: strings are immutable
    name = name.lstrip('%')  # works now that leading whitespace is gone
    name = name.replace(" ", "")
    name = name.replace('_', r'\_')  # raw string: '\_' is an invalid escape
    name = name.replace('\n', '')
    return name
def computeLimits(data, selected_ids, column_map=None, margin=0.05):
    """Return padded axis limits covering the two selected columns of all data.

    :param data: dict mapping a key (str(generation) in getXY) to a 2-D
        array with one row per solution.
    :param selected_ids: names of the x and y columns (entries 0 and 1).
    :param column_map: name -> column index; defaults to the module-level
        ``nameToColumnMap`` filled by readJSONData.
    :param margin: fraction of each column's span added on both sides.
    :returns: ``(xlim, ylim)``, each a ``[low, high]`` list.
    :raises ValueError: if ``data`` contains no rows at all.
    """
    if column_map is None:
        column_map = nameToColumnMap
    x_idx = column_map[selected_ids[0]]
    y_idx = column_map[selected_ids[1]]

    # Skip tables without rows: min/max of an empty column is undefined.
    tables = [t for t in (np.asarray(d) for d in data.values()) if t.size]
    if not tables:
        raise ValueError("computeLimits: no data to compute limits from")

    def padded_range(idx):
        # True extrema over every table; fixed +/-1000 start values would
        # clamp ranges lying entirely outside [-1000, 1000].
        low = min(float(np.min(t[:, idx])) for t in tables)
        high = max(float(np.max(t[:, idx])) for t in tables)
        # Compute the pad once so both ends grow by the same amount.
        pad = margin * (high - low)
        return [low - pad, high + pad]

    return (padded_range(x_idx), padded_range(y_idx))
def getXY(generation, path, filename_postfix, selected_ids):
    """Load one generation from disk and return its two selected columns.

    The file ``<path>/<generation>_<filename_postfix>`` is (re)read on every
    call and cached in the module-level ``data`` dict under str(generation).

    :param generation: generation number (int or numeric string).
    :param path: directory holding the results files ('' = current dir).
    :param filename_postfix: file-name part after ``<generation>_``.
    :param selected_ids: names of the x and y columns (entries 0 and 1).
    :returns: ``(x, y)`` column arrays.
    """
    key = str(generation)
    # os.path.join keeps an empty path relative instead of rooting it at '/'.
    fn = os.path.join(path, key + '_' + filename_postfix)
    table = readJSONData(fn)
    data[key] = table
    # (Axis limits over all generations used to be computed here and then
    # discarded; that cost grew with every generation viewed.)
    x_idx = nameToColumnMap[selected_ids[0]]
    y_idx = nameToColumnMap[selected_ids[1]]
    return table[:, x_idx], table[:, y_idx]
class Plotter:
    """Interactive scatter plot of a data source with one slider per variable.

    The data source given to plot() must provide getX(), getY(),
    get_variables() (a list of ``(name, min, max)`` tuples), readData(),
    readInitialData(), and an attribute for every slider name (the slider
    writes its value to that attribute before calling readData()).
    """

    def __init__(self):
        self.fig, self.ax = plt.subplots()
        self.sliders = []
        self.buttons = []

    def setupPlot(self, width=1388.5):
        """Configure LaTeX text, font sizes and a golden-ratio figure size.

        :param width: figure width in TeX points (72.27 pt per inch).
        """
        inches_per_pt = 1.0 / 72.27  # convert pt to inch
        golden_mean = (math.sqrt(5) - 1.0) / 2.0  # aesthetic height/width ratio
        fig_width = width * inches_per_pt
        fig_height = fig_width * golden_mean
        fig_size = [fig_width, fig_height]
        # No 'backend' entry: this viewer needs an interactive backend for its
        # sliders and plt.show(); forcing the non-interactive 'ps' one is wrong.
        params = {'axes.labelsize': 14,
                  'font.size': 14,
                  'legend.fontsize': 14,
                  'xtick.labelsize': 14,
                  'ytick.labelsize': 14,
                  'text.usetex': True,
                  'figure.figsize': fig_size}
        pl.rcParams.update(params)
        # __init__ already created the figure, so the figsize rcParam alone
        # would not affect it; resize the existing figure explicitly.
        self.fig.set_size_inches(fig_width, fig_height, forward=True)

    def plot(self, obj):
        """Scatter ``obj``'s data, add sliders and a Reset button, then show."""
        self.obj = obj
        self.l = self.ax.plot(obj.getX(), obj.getY(), '*')
        _vars = obj.get_variables()
        # Leave a 0.03-high strip at the bottom per slider plus some headroom.
        self.fig.subplots_adjust(bottom=0.03 * (len(_vars) + 2))
        self.sliders = []
        self.buttons = []
        for i, var in enumerate(_vars):
            self.add_slider(i * 0.03, var[0], var[1], var[2])
        self.add_reset()
        plt.show()

    def _redraw(self):
        """Push the data source's current x/y into the line and rescale."""
        self.l[0].set_data(self.obj.getX(), self.obj.getY())
        # Points can move far between generations; refit the view to them.
        self.ax.relim()
        self.ax.autoscale_view()
        self.fig.canvas.draw_idle()

    def add_reset(self):
        """Add a 'Reset' button restoring the initial data and all sliders."""
        ax = self.fig.add_axes([0.9, 0.9, 0.1, 0.04])
        resbutt = Button(ax, 'Reset', color='lightgoldenrodyellow', hovercolor='0.975')
        # Keep a reference; otherwise the widget may be garbage-collected.
        self.buttons.append(resbutt)

        def update(_event):
            self.obj.readInitialData()
            self._redraw()
            # Reset every slider, not only the first one.
            for slider in self.sliders:
                slider.reset()

        resbutt.on_clicked(update)

    def add_slider(self, pos, name, min, max):
        """Add a horizontal slider driving attribute ``name`` of the data source.

        :param pos: vertical offset (figure fraction) of the slider strip.
        :param name: attribute set on the data source; also the slider label.
        :param min: lower slider bound.
        :param max: upper slider bound.
        """
        ax = self.fig.add_axes([0.1, 0.02 + pos, 0.8, 0.02],
                               facecolor='lightgoldenrodyellow')  # 'axisbg' was removed
        slider = Slider(ax, name, min, max,
                        valinit=int(getattr(self.obj, name)), valfmt="%d")
        self.sliders.append(slider)

        def update(val):
            # Store the new value first so readData() loads the selection
            # just made, not the previous one.
            setattr(self.obj, name, val)
            self.obj.readData()
            self._redraw()

        slider.on_changed(update)
class OptData:
    """Data source for Plotter: two selected columns of one optimizer generation.

    Attributes:
        generation: generation to load on readData(); may be overwritten
            externally (e.g. by a slider) with an int, float or numeric string.
        initialgeneration: generation given at construction, used by
            readInitialData().
        path, filename_postfix, selected_ids: forwarded to getXY().
        x, y: the most recently loaded column arrays.
    """

    def __init__(self, generation, path, filename_postfix, selected_ids):
        self.path = path
        self.filename_postfix = filename_postfix
        self.selected_ids = selected_ids
        self.generation = self.initialgeneration = generation
        self.x, self.y = getXY(generation, path, filename_postfix, selected_ids)

    def _load(self, which):
        """Read generation ``int(which)`` into self.x / self.y."""
        loaded = getXY(int(which), self.path, self.filename_postfix, self.selected_ids)
        self.x, self.y = loaded

    def readData(self):
        """Reload the data for the current ``generation``."""
        self._load(self.generation)

    def readInitialData(self):
        """Reload the data for the generation given at construction."""
        self._load(self.initialgeneration)

    def getX(self):
        """Return the x column of the loaded generation."""
        return self.x

    def getY(self):
        """Return the y column of the loaded generation."""
        return self.y

    def get_variables(self):
        """Return the slider specs as ``(name, min, max)`` tuples."""
        generation_slider = ('generation', 0., 5000.)
        return [generation_slider]
def main(argv):
    """Parse ``--key=value`` options and open the interactive generation viewer.

    Options:
      --objectives=a,b       objective names to plot (appended to selected_ids)
      --dvars=a,b            design-variable names to plot (appended too)
      --path=DIR             directory with the ``<generation>_<postfix>`` files
      --filename-postfix=S   file-name part after ``<generation>_``
      --generation=N         generation shown initially
      --outpath=DIR, --video=NAME, --plot-all
                             stored, but not used by the viewer in this file
    The first two selected names become the x and y axes.
    """
    # Write into the module-level defaults: plain assignments would create
    # locals and leave e.g. `path` unbound whenever --path was omitted.
    global path, videoname, outpath, filename_postfix, generation, plotAll

    def value_of(arg):
        # Everything after the first '=', so values may contain '=' too.
        return arg.partition("=")[2]

    for arg in argv:
        if arg.startswith(("--objectives", "--dvars")):
            for name in value_of(arg).split(","):
                selected_ids.append(improveName(name))
        elif arg.startswith("--path"):
            path = value_of(arg)
        elif arg.startswith("--filename-postfix"):
            filename_postfix = value_of(arg)
        elif arg.startswith("--outpath"):
            outpath = value_of(arg)
        elif arg.startswith("--video"):
            videoname = value_of(arg)
        elif arg.startswith("--generation"):
            generation = value_of(arg)
        elif arg.startswith("--plot-all"):
            plotAll = True

    if len(selected_ids) < 2:
        sys.exit("Select at least two names via --objectives/--dvars "
                 "(used as x and y axes).")

    k = Plotter()
    k.setupPlot()
    k.plot(OptData(generation, path, filename_postfix, selected_ids))


# call main
if __name__ == "__main__":
    main(sys.argv[1:])