Source code for optPilot.visualize

#!/usr/bin/python
# Copyright (c) 2017, Andreas Adelmann, Paul Scherrer Institut, Villigen PSI, Switzerland
# All rights reserved
#
# This file is part of pyOPALTools.
#
# pyOPALTools is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# You should have received a copy of the GNU General Public License
# along with pyOPALTools. If not, see <https://www.gnu.org/licenses/>.
import os
import sys
import glob
import json
import math

import numpy as np
import pylab as pl
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

from collections import OrderedDict

from optPilot.Annotate import AnnoteFinder

from opal.parser.OptimizerParser import OptimizerParser as jsonreader

# Data parsing
##############################################################################

# Module-level state shared by the functions in this file.

# name of the columns in the solution data: readJSONData() maps every design
# variable name, every objective name and the key "ID" to a column index
nameToColumnMap = {}
# names of the quantities to plot; main() appends the names given with
# --objectives / --dvars (entries 0 and 1 are used as the x and y columns)
selected_ids = []
# presumably the directory containing the result files (cf. --path in main())
path = ""
# cf. --video in main(); not used by the plotting code visible in this file
videoname = ""
# cf. --outpath in main(); not used by the plotting code visible in this file
outpath = "./"
# cf. --filename-postfix in main(); getXY() builds file names of the form
# "<generation>_<postfix>"
filename_postfix = "results.json"
# cf. --generation in main()
generation = -1
# cf. --plot-all in main(); not used by the plotting code visible in this file
plotAll = False

# loaded generations: getXY() stores the array returned by readJSONData()
# under the key str(generation)
data = { }

def readJSONData(filename):
    """Load one generation of optimizer results as a plain 2-D array.

    The directory part of ``filename`` is handed to the optimizer JSON
    parser; the generation number is the part of the base name in front of
    the first underscore (``<generation>_<postfix>``).

    Side effect: the module-level ``nameToColumnMap`` is updated so that
    every design variable, every objective and the key ``"ID"`` point to a
    column index (this assumes the parser returns one input/output column
    per design variable/objective, in the same order -- TODO confirm).

    Parameters
    ----------
    filename : str
        Path of the form ``<dir>/<generation>_<postfix>``.

    Returns
    -------
    numpy.ndarray
        Columns stacked as ``[design variables | objectives | ids]``.
    """
    directory, base = os.path.split(filename)

    reader = jsonreader()
    reader.parse(directory + '/')

    # the generation is encoded in the file name: "<generation>_..."
    reader.readGeneration(int(base.split("_", 1)[0]))

    # name -> column map: design variables first, then objectives, then ID
    columns = list(reader.getDesignVariables()) + list(reader.getObjectives())
    for column, name in enumerate(columns):
        nameToColumnMap[name] = column
    nameToColumnMap["ID"] = len(columns)

    # data matrix, columns in the same order as the map above
    return np.column_stack(
        (reader.getAllInput(), reader.getAllOutput(), reader.getIDs()))
def improveName(name):
    """Clean up a variable/objective name given on the command line.

    The steps are, in order: strip leading and trailing whitespace, drop
    leading ``%`` characters, remove all spaces, escape underscores for
    LaTeX (``_`` becomes ``\\_``) and remove newlines.

    Parameters
    ----------
    name : str
        Raw name.

    Returns
    -------
    str
        The cleaned name.
    """
    # The result of the strip must be kept: strings are immutable, and
    # without it a leading '%' preceded by whitespace would survive below.
    name = name.strip()
    name = name.lstrip('%')             # remove leading %
    name = name.replace(" ", "")        # remove spaces
    name = name.replace('_', '\\_')     # latex handling of underscore
    name = name.replace('\n', '')       # remove newlines
    return name
def _paddedRange(lo, hi, fraction=0.05):
    """Return ``[lo, hi]`` widened on both sides by ``fraction`` of its span."""
    pad = fraction * (hi - lo)
    return [lo - pad, hi + pad]


def computeLimits(data, selected_ids):
    """Compute common axis limits over all loaded generations.

    The columns to inspect are looked up in the module-level
    ``nameToColumnMap`` using the first two entries of ``selected_ids``.

    Parameters
    ----------
    data : dict
        Maps a key to a 2-D array (indexed as ``d[:, column]``).
    selected_ids : sequence of str
        ``selected_ids[0]`` names the x column, ``selected_ids[1]`` the y
        column.

    Returns
    -------
    tuple
        ``(xlim, ylim)``, each a two-element list ``[lower, upper]`` that
        spans the data plus a 5% margin on either side.  If ``data`` is
        empty both are the empty interval ``[inf, -inf]``.

    Raises
    ------
    KeyError
        If a selected name is not in ``nameToColumnMap``.
    """
    x_idx = nameToColumnMap[selected_ids[0]]
    y_idx = nameToColumnMap[selected_ids[1]]

    # Seed with +/-inf rather than fixed numbers so that data of any
    # magnitude (e.g. all values above 1000) yields correct bounds.
    x_lo = y_lo = float("inf")
    x_hi = y_hi = float("-inf")

    for d in data.values():
        xs = d[:, x_idx]
        ys = d[:, y_idx]
        x_lo = min(x_lo, np.min(xs))
        x_hi = max(x_hi, np.max(xs))
        y_lo = min(y_lo, np.min(ys))
        y_hi = max(y_hi, np.max(ys))

    # The margin is computed once per axis from the unpadded span, so the
    # lower and upper margins are equal.
    return (_paddedRange(x_lo, x_hi), _paddedRange(y_lo, y_hi))
def getXY(generation, path, filename_postfix, selected_ids):
    """Load one generation and return the two selected columns.

    The file ``<path>/<generation>_<filename_postfix>`` is read with
    ``readJSONData`` and the resulting array is stored in the module-level
    ``data`` dictionary under the key ``str(generation)``.

    Parameters
    ----------
    generation : int or str
        Generation number; ``str(generation)`` is used in the file name.
    path : str
        Directory containing the result files.
    filename_postfix : str
        Part of the file name after ``"<generation>_"``.
    selected_ids : sequence of str
        ``selected_ids[0]`` names the x column, ``selected_ids[1]`` the y
        column (looked up in ``nameToColumnMap``).

    Returns
    -------
    tuple
        ``(x, y)``: the two selected columns of the loaded array.
    """
    key = str(generation)
    fn = os.path.join(path, key + '_' + filename_postfix)
    data[key] = readJSONData(fn)

    # NOTE: axis limits used to be computed here over every cached
    # generation and then thrown away; that dead work has been removed.

    # the map is (re)built by readJSONData, so look the columns up afterwards
    x_idx = nameToColumnMap[selected_ids[0]]
    y_idx = nameToColumnMap[selected_ids[1]]

    generation_data = data[key]
    return generation_data[:, x_idx], generation_data[:, y_idx]
class Plotter:
    """Interactive matplotlib view of the data held by a data object.

    The data object passed to :meth:`plot` must provide ``getX()``,
    ``getY()``, ``get_variables()``, ``readData()``, ``readInitialData()``
    and a ``generation`` attribute (``OptData`` in this file does).
    """

    def __init__(self):
        """Create the figure and axes used for plotting."""
        self.fig, self.ax = plt.subplots()

    def setupPlot(self, width=1388.5):
        """Update the matplotlib rcParams (figure size, fonts, LaTeX text).

        Parameters
        ----------
        width : float
            Figure width in TeX points; the height follows from the golden
            ratio.
        """
        fig_width_pt = width
        inches_per_pt = 1.0/72.27               # Convert pt to inch
        golden_mean = (math.sqrt(5)-1.0)/2.0    # Aesthetic ratio
        fig_width = fig_width_pt*inches_per_pt  # width in inches
        fig_height = fig_width*golden_mean      # height in inches
        fig_size = [fig_width, fig_height]
        # NOTE(review): 'figure.figsize' only affects figures created after
        # this call, while __init__ has already created self.fig -- confirm
        # whether the existing figure is meant to be resized.
        params = {'backend': 'ps',
                  'axes.labelsize': 14,
                  'font.size': 14,
                  'legend.fontsize': 14,
                  'xtick.labelsize': 14,
                  'ytick.labelsize': 14,
                  'text.usetex': True,
                  'figure.figsize': fig_size}
        pl.rcParams.update(params)

    def plot(self, obj):
        """Plot ``obj``, add one slider per variable plus a reset button,
        and show the window.

        Parameters
        ----------
        obj : object
            Data object; each entry of ``obj.get_variables()`` is a
            ``(name, min, max)`` triple describing one slider.
        """
        self.obj = obj
        self.l = plt.plot(obj.getX(), obj.getY(), '*')
        _vars = obj.get_variables()
        # leave room below the axes for the sliders
        plt.subplots_adjust(bottom=0.03*(len(_vars)+2))
        self.sliders = []
        self.buttons = []
        for i, var in enumerate(_vars):
            self.add_slider(i*0.03, var[0], var[1], var[2])
        self.add_reset()
        plt.show()

    def add_reset(self):
        """Add a 'Reset' button that restores the initial data."""
        axcolor = 'lightgoldenrodyellow'
        ax = plt.axes([0.9, 0.9, 0.1, 0.04])
        resbutt = Button(ax, 'Reset', color=axcolor, hovercolor='0.975')
        # keep a reference so the widget stays alive and responsive
        self.buttons.append(resbutt)

        def update(val):
            self.obj.readInitialData()
            self.l[0].set_ydata(self.obj.getY())
            self.l[0].set_xdata(self.obj.getX())
            # reset every slider; looping also avoids an IndexError when
            # the data object declares no variables
            for slider in self.sliders:
                slider.reset()
            self.fig.canvas.draw_idle()

        resbutt.on_clicked(update)

    def add_slider(self, pos, name, min, max):
        """Add a slider that sets ``self.obj.<name>`` and reloads the data.

        Parameters
        ----------
        pos : float
            Vertical offset of the slider axes (figure coordinates).
        name : str
            Slider label and name of the attribute set on the data object.
        min, max : float
            Slider range.
        """
        # 'facecolor' replaces the 'axisbg' keyword, which current
        # matplotlib versions no longer accept.
        ax = plt.axes([0.1, 0.02+pos, 0.8, 0.02],
                      facecolor='lightgoldenrodyellow')
        # the initial value is always taken from obj.generation
        slider = Slider(ax, name, min, max,
                        valinit=int(self.obj.generation), valfmt="%d")
        self.sliders.append(slider)

        def update(val):
            # store the new value first so that readData() loads the data
            # for the value just selected, not for the previous one
            setattr(self.obj, name, val)
            self.obj.readData()
            self.l[0].set_ydata(self.obj.getY())
            self.l[0].set_xdata(self.obj.getX())
            self.fig.canvas.draw_idle()

        slider.on_changed(update)
class OptData:
    """x/y columns of one generation for a pair of selected quantities.

    The data are obtained through the module-level ``getXY`` function.
    """

    def __init__(self, generation, path, filename_postfix, selected_ids):
        """Remember the settings and load the data of ``generation``."""
        self.generation = generation
        self.initialgeneration = generation
        self.path = path
        self.filename_postfix = filename_postfix
        self.selected_ids = selected_ids
        self.x, self.y = self._fetch(generation)

    def _fetch(self, generation):
        """Return ``(x, y)`` of ``generation`` using the stored settings."""
        return getXY(generation, self.path, self.filename_postfix,
                     self.selected_ids)

    def readData(self):
        """Reload x/y for the current ``generation`` (converted to int)."""
        self.x, self.y = self._fetch(int(self.generation))

    def readInitialData(self):
        """Reload x/y for the generation given at construction time."""
        self.x, self.y = self._fetch(int(self.initialgeneration))

    def getX(self):
        """Return the currently loaded x values."""
        return self.x

    def getY(self):
        """Return the currently loaded y values."""
        return self.y

    def get_variables(self):
        """Return the adjustable variables as ``(name, min, max)`` triples."""
        return [('generation', 0., 5000.)]
def main(argv):
    """Parse ``--key=value`` style arguments and open the interactive plot.

    Recognised arguments::

        --objectives=a,b       names appended to the module-level selected_ids
        --dvars=a,b            names appended to the module-level selected_ids
        --path=DIR             directory with the result files
        --filename-postfix=S   file name part after "<generation>_"
        --generation=N         generation to show first
        --outpath=DIR          parsed, currently unused
        --video=NAME           parsed, currently unused
        --plot-all             parsed, currently unused

    Options that are not given keep the module-level default of the same
    name.  Unknown arguments are ignored.

    Parameters
    ----------
    argv : list of str
        Command-line arguments without the program name.

    Raises
    ------
    IndexError
        If an option that needs a value is given without ``=``.
    """
    # Collect the settings in a dict seeded with the module-level defaults.
    # Assigning to names such as ``path`` directly would make them local to
    # this function and raise UnboundLocalError for every option omitted.
    options = {
        "path": path,
        "filename_postfix": filename_postfix,
        "outpath": outpath,
        "videoname": videoname,
        "generation": generation,
        "plotAll": plotAll,
    }

    def value_of(arg):
        # split only at the first '=' so the value itself may contain '='
        return arg.split("=", 1)[1]

    for arg in argv:
        if arg.startswith(("--objectives", "--dvars")):
            for name in value_of(arg).split(","):
                selected_ids.append(improveName(name))
        elif arg.startswith("--path"):
            options["path"] = value_of(arg)
        elif arg.startswith("--filename-postfix"):
            options["filename_postfix"] = value_of(arg)
        elif arg.startswith("--outpath"):
            options["outpath"] = value_of(arg)
        elif arg.startswith("--video"):
            options["videoname"] = value_of(arg)
        elif arg.startswith("--generation"):
            options["generation"] = value_of(arg)
        elif arg.startswith("--plot-all"):
            options["plotAll"] = True

    k = Plotter()
    k.setupPlot()
    k.plot(OptData(options["generation"], options["path"],
                   options["filename_postfix"], selected_ids))
# Run the command-line entry point when this file is executed as a script.
if __name__ == "__main__":
    main(sys.argv[1:])