from builtins import str
from builtins import object
import os, warnings
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, Float
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship, backref
from sqlalchemy.exc import DatabaseError
from lsst.daf.persistence import DbAuth
import numpy as np
# Declarative base class shared by all of the table definitions in this module.
Base = declarative_base()

# Names exported by a star-import of this module.
__all__ = [
    'MetricRow',
    'DisplayRow',
    'PlotRow',
    'SummaryStatRow',
    'ResultsDb',
]
class MetricRow(Base):
    """
    Define contents and format of metric list table.

    (Table to list all metrics, their metadata, and their output data files).
    """
    __tablename__ = "metrics"
    # Define columns in metric list table.
    metricId = Column(Integer, primary_key=True)
    metricName = Column(String)
    slicerName = Column(String)
    simDataName = Column(String)
    sqlConstraint = Column(String)
    metricMetadata = Column(String)
    metricDataFile = Column(String)

    def __repr__(self):
        """Return a one-line, human readable summary of this row."""
        # metricId is formatted with '%s' rather than '%d': it can still be None on a row
        # that has not been written to the database, and '%d' would raise TypeError then.
        # Integer ids render identically with either format.
        return ("<Metric(metricId='%s', metricName='%s', slicerName='%s', simDataName='%s', "
                "sqlConstraint='%s', metadata='%s', metricDataFile='%s')>"
                % (self.metricId, self.metricName, self.slicerName, self.simDataName,
                   self.sqlConstraint, self.metricMetadata, self.metricDataFile))
class DisplayRow(Base):
    """
    Define contents and format of the displays table.

    (Table to list the display properties for each metric.)
    """
    __tablename__ = "displays"
    displayId = Column(Integer, primary_key=True)
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    # Group for displaying metric (in webpages).
    displayGroup = Column(String)
    # Subgroup for displaying metric.
    displaySubgroup = Column(String)
    # Order to display metric (within subgroup).
    displayOrder = Column(Float)
    # The figure caption.
    displayCaption = Column(String)
    metric = relationship("MetricRow", backref=backref('displays', order_by=displayId))

    def __repr__(self):
        """Return a one-line, human readable summary of this row."""
        # This method was previously spelled '__rep__', so Python never used it as the repr.
        # Guard the float format: '%.1f' raises TypeError when displayOrder is unset (None).
        order = self.displayOrder
        orderText = 'None' if order is None else '%.1f' % order
        return ("<Display(displayGroup='%s', displaySubgroup='%s', displayOrder='%s', "
                "displayCaption='%s')>"
                % (self.displayGroup, self.displaySubgroup, orderText, self.displayCaption))
class PlotRow(Base):
    """
    Define contents and format of plot list table.

    (Table to list all plots, link them to relevant metrics in MetricList, and provide info on filename).
    """
    __tablename__ = "plots"
    # Define columns in plot list table.
    plotId = Column(Integer, primary_key=True)
    # Matches metricID in MetricList table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    plotType = Column(String)
    plotFile = Column(String)
    metric = relationship("MetricRow", backref=backref('plots', order_by=plotId))

    def __repr__(self):
        """Return a one-line, human readable summary of this row."""
        # '%s' instead of '%d' so that a row whose metricId is still None can be printed;
        # integer ids render identically with either format.
        return ("<Plot(metricId='%s', plotType='%s', plotFile='%s')>"
                % (self.metricId, self.plotType, self.plotFile))
class SummaryStatRow(Base):
    """
    Define contents and format of the summary statistics table.

    (Table to list and link summary stats to relevant metrics in MetricList, and provide summary stat name,
    value and potentially a comment).
    """
    __tablename__ = "summarystats"
    # Define columns in summary statistics table.
    statId = Column(Integer, primary_key=True)
    # Matches metricID in MetricList table.
    metricId = Column(Integer, ForeignKey('metrics.metricId'))
    summaryName = Column(String)
    summaryValue = Column(Float)
    metric = relationship("MetricRow", backref=backref('summarystats', order_by=statId))

    def __repr__(self):
        """Return a one-line, human readable summary of this row."""
        # Numeric formats raise TypeError on None, so unset fields are rendered as 'None'
        # instead; set values are formatted exactly as before ('%d' / '%f').
        value = self.summaryValue
        valueText = 'None' if value is None else '%f' % value
        return ("<SummaryStat(metricId='%s', summaryName='%s', summaryValue='%s')>"
                % (self.metricId, self.summaryName, valueText))
class ResultsDb(object):
    """
    Wrapper around a database session holding the metrics, displays, plots and
    summarystats tables, with helpers to add rows to and query rows from those tables.
    """

    def __init__(self, outDir=None, database=None, driver='sqlite',
                 host=None, port=None, verbose=False):
        """
        Instantiate the results database, creating metrics, plots and summarystats tables.

        - outDir: directory holding the sqlite database file. When 'database' is None this
          defaults to '.' and is created if it does not exist.
        - database: database name (for sqlite: the file name, joined onto outDir if given).
          When None, 'resultsDb_sqlite.db' inside outDir is used and the driver is sqlite.
        - driver: the database driver name passed to the connection url
        - host, port: connection details, used for non-sqlite drivers
        - verbose: passed as 'echo' to the database engine

        Raises OSError if outDir cannot be created, and ValueError if the tables
        cannot be created in the database.
        """
        # Always define these so the attributes exist whichever branch is taken below.
        self.host = host
        self.port = port
        # Connect to database
        # for sqlite, connecting to non-existent database creates it automatically
        if database is None:
            # Using default value for database name, should specify directory.
            if outDir is None:
                outDir = '.'
            # Check for output directory, make if needed.
            if not os.path.isdir(outDir):
                try:
                    os.makedirs(outDir)
                except OSError as msg:
                    raise OSError(msg, '\n (If this was the database file (not outDir), '
                                       'remember to use kwarg "database")')
            self.database = os.path.join(outDir, 'resultsDb_sqlite.db')
            self.driver = 'sqlite'
        else:
            # Using non-default database; for sqlite a directory root may also be specified.
            # If not sqlite, then 'outDir' doesn't make much sense and is ignored.
            if driver == 'sqlite' and outDir is not None:
                database = os.path.join(outDir, database)
            self.database = database
            self.driver = driver
        if self.driver == 'sqlite':
            dbAddress = url.URL(self.driver, database=self.database)
        else:
            dbAddress = url.URL(self.driver,
                                username=DbAuth.username(self.host, str(self.port)),
                                password=DbAuth.password(self.host, str(self.port)),
                                host=self.host,
                                port=self.port,
                                database=self.database)
        engine = create_engine(dbAddress, echo=verbose)
        self.Session = sessionmaker(bind=engine)
        self.session = self.Session()
        # Create the tables, if they don't already exist.
        try:
            Base.metadata.create_all(engine)
        except DatabaseError:
            raise ValueError("Cannot create a %s database at %s. Check directory exists."
                             % (self.driver, self.database))
        # Maximum string length used for the string fields of the numpy arrays returned
        # by the get* methods.
        self.slen = 1024

    def close(self):
        """
        Close connection to database.
        """
        self.session.close()

    @staticmethod
    def _toNative(value):
        """Convert a numpy scalar to the equivalent python object; pass anything else through."""
        if isinstance(value, np.generic):
            return value.item()
        return value

    def _resolveMetricIds(self, metricId):
        """Return an iterable of metric ids: all ids for None, a one-element list for a single id."""
        if metricId is None:
            return self.getAllMetricIds()
        if not hasattr(metricId, '__iter__'):
            return [metricId]
        return metricId

    def updateMetric(self, metricName, slicerName, simDataName, sqlConstraint,
                     metricMetadata, metricDataFile):
        """
        Add a row to the metrics table, unless a matching row already exists.

        - metricName: the name of the metric
        - slicerName: the name of the slicer
        - simDataName: the name used to identify the simData
        - sqlConstraint: the sql constraint used to select data from the simData
        - metricMetadata: the metadata associated with the metric
        - metricDataFile: the data file the metric data is stored in

        If same metric (same metricName, slicerName, simDataName, sqlConstraint, metadata)
        already exists, it does nothing.
        None values for simDataName, sqlConstraint, metricMetadata and metricDataFile are
        stored as the string 'NULL'.

        Returns metricId: the Id number of this metric in the metrics table.
        """
        if simDataName is None:
            simDataName = 'NULL'
        if sqlConstraint is None:
            sqlConstraint = 'NULL'
        if metricMetadata is None:
            metricMetadata = 'NULL'
        if metricDataFile is None:
            metricDataFile = 'NULL'
        # Check if metric has already been added to database.
        # (metricDataFile is deliberately not part of the match.)
        metricinfo = self.session.query(MetricRow).filter_by(metricName=metricName,
                                                             slicerName=slicerName,
                                                             simDataName=simDataName,
                                                             metricMetadata=metricMetadata,
                                                             sqlConstraint=sqlConstraint).first()
        if metricinfo is None:
            metricinfo = MetricRow(metricName=metricName, slicerName=slicerName,
                                   simDataName=simDataName, sqlConstraint=sqlConstraint,
                                   metricMetadata=metricMetadata, metricDataFile=metricDataFile)
            self.session.add(metricinfo)
            self.session.commit()
        return metricinfo.metricId

    def updateDisplay(self, metricId, displayDict, overwrite=True):
        """
        Add a row to or update a row in the displays table.

        - metricId: the metric Id of this metric in the metrics table
        - displayDict: dictionary containing the display info, read from the keys
          'group', 'subgroup', 'order' and 'caption'. Missing or None entries are stored
          as 'NULL' (or 0 for 'order'). The dictionary itself is not modified.
        - overwrite: if True, replace any existing rows with the same metricId;
          if False and a row already exists, do nothing.
        """
        # Because we want to maintain 1-1 relationship between metricId's and displayDict's:
        # First check if a display line is present with this metricID.
        existing = self.session.query(DisplayRow).filter_by(metricId=metricId).all()
        if existing:
            if not overwrite:
                return
            for row in existing:
                self.session.delete(row)
        # Then go ahead and add new displayDict.
        # Build the values in a fresh dictionary, so the caller's displayDict is left untouched.
        info = {'group': 'NULL', 'subgroup': 'NULL', 'order': 'NULL', 'caption': 'NULL'}
        for key in displayDict:
            value = displayDict[key]
            info[key] = 'NULL' if value is None else value
        if info['order'] == 'NULL':
            info['order'] = 0
        displayCaption = info['caption']
        autoTag = '(auto)'
        if displayCaption.endswith(autoTag):
            # Strip the trailing tag itself; str.replace(..., 1) would instead remove the
            # first occurrence, which is the wrong one if the tag also appears earlier.
            displayCaption = displayCaption[:-len(autoTag)]
        displayinfo = DisplayRow(metricId=metricId,
                                 displayGroup=info['group'], displaySubgroup=info['subgroup'],
                                 displayOrder=info['order'], displayCaption=displayCaption)
        self.session.add(displayinfo)
        self.session.commit()

    def updatePlot(self, metricId, plotType, plotFile):
        """
        Add a row to or update a row in the plot table.

        - metricId: the metric Id of this metric in the metrics table
        - plotType: the 'type' of this plot
        - plotFile: the filename of this plot

        Remove older rows with the same metricId, plotType and plotFile.
        """
        duplicates = self.session.query(PlotRow).filter_by(metricId=metricId, plotType=plotType,
                                                           plotFile=plotFile).all()
        for row in duplicates:
            self.session.delete(row)
        plotinfo = PlotRow(metricId=metricId, plotType=plotType, plotFile=plotFile)
        self.session.add(plotinfo)
        self.session.commit()

    def updateSummaryStat(self, metricId, summaryName, summaryValue):
        """
        Add a row to or update a row in the summary statistic table.

        - metricId: the metric ID of this metric in the metrics table
        - summaryName: the name of this summary statistic
        - summaryValue: the value for this summary statistic

        Most summary statistics will be a simple name (string) + value (float) pair.
        For special summary statistics which must return multiple values, the base name
        can be provided as 'name', together with a np recarray as 'value', where the
        recarray also has 'name' and 'value' columns (and each name/value pair is then saved
        as a summary statistic associated with this same metricId).

        Values which are neither of these are not saved; a warning is issued instead.
        """
        # Allow for special summary statistics which return data in a np structured array with
        # 'name' and 'value' columns. (specificially needed for TableFraction summary statistic).
        if isinstance(summaryValue, np.ndarray):
            # dtype.names is None for a plain (non-structured) array.
            fieldNames = summaryValue.dtype.names or ()
            if 'name' in fieldNames and 'value' in fieldNames:
                for entry in np.atleast_1d(summaryValue):
                    sSuffix = entry['name']
                    if isinstance(sSuffix, bytes):
                        sSuffix = sSuffix.decode('utf-8')
                    else:
                        sSuffix = str(sSuffix)
                    summarystat = SummaryStatRow(metricId=metricId,
                                                 summaryName=summaryName + ' ' + sSuffix,
                                                 summaryValue=self._toNative(entry['value']))
                    self.session.add(summarystat)
                self.session.commit()
            else:
                warnings.warn('Warning! Cannot save non-conforming summary statistic.')
        # Most summary statistics will be simple floats.
        elif isinstance(summaryValue, (float, int, np.floating, np.integer)):
            # numpy scalars such as np.float32 / np.int64 are not python float/int subclasses;
            # accept them too and store the equivalent native python number.
            summarystat = SummaryStatRow(metricId=metricId, summaryName=summaryName,
                                         summaryValue=self._toNative(summaryValue))
            self.session.add(summarystat)
            self.session.commit()
        else:
            warnings.warn('Warning! Cannot save summary statistic that is not a simple float or int')

    def getMetricId(self, metricName, slicerName=None, metricMetadata=None, simDataName=None):
        """
        Given a metric name and optional slicerName/metricMetadata/simData information,
        Return a list of the matching metricIds.

        The list is ordered by slicerName, then metricMetadata.
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName).filter(MetricRow.metricName == metricName)
        if slicerName is not None:
            query = query.filter(MetricRow.slicerName == slicerName)
        if metricMetadata is not None:
            query = query.filter(MetricRow.metricMetadata == metricMetadata)
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        query = query.order_by(MetricRow.slicerName, MetricRow.metricMetadata)
        return [m.metricId for m in query]

    def getMetricIdLike(self, metricNameLike=None, slicerNameLike=None,
                        metricMetadataLike=None, simDataName=None):
        """
        Return a list of the metricIds whose metric name, slicer name and/or metadata
        contain the given strings (SQL 'LIKE' match on '%value%').

        - metricNameLike: substring to match within the metric name
        - slicerNameLike: substring to match within the slicer name
        - metricMetadataLike: substring to match within the metric metadata
        - simDataName: exact simData name to match

        Any argument left as None places no restriction on the query.
        """
        query = self.session.query(MetricRow.metricId, MetricRow.metricName, MetricRow.slicerName,
                                   MetricRow.metricMetadata,
                                   MetricRow.simDataName)
        if metricNameLike is not None:
            query = query.filter(MetricRow.metricName.like('%' + str(metricNameLike) + '%'))
        if slicerNameLike is not None:
            query = query.filter(MetricRow.slicerName.like('%' + str(slicerNameLike) + '%'))
        if metricMetadataLike is not None:
            query = query.filter(MetricRow.metricMetadata.like('%' + str(metricMetadataLike) + '%'))
        if simDataName is not None:
            query = query.filter(MetricRow.simDataName == simDataName)
        return [m.metricId for m in query]

    def getAllMetricIds(self):
        """
        Return a list of all metricIds.
        """
        return [m.metricId for m in self.session.query(MetricRow.metricId).all()]

    def getSummaryStats(self, metricId=None, summaryName=None):
        """
        Get the summary stats (optionally for metricId list).
        Optionally, also specify the summary metric name.

        Returns a numpy array of the metric information + summary statistic information,
        with fields metricId, metricName, slicerName, metricMetadata, summaryName, summaryValue.
        """
        summarystats = []
        for mid in self._resolveMetricIds(metricId):
            # Join the metric table and the summarystat table, based on the metricID (the second filter)
            query = (self.session.query(MetricRow, SummaryStatRow)
                     .filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == SummaryStatRow.metricId))
            if summaryName is not None:
                query = query.filter(SummaryStatRow.summaryName == summaryName)
            summarystats.extend((m.metricId, m.metricName, m.slicerName, m.metricMetadata,
                                 s.summaryName, s.summaryValue) for m, s in query)
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen), ('metricMetadata', np.str_, self.slen),
                          ('summaryName', np.str_, self.slen), ('summaryValue', float)])
        return np.array(summarystats, dtype)

    def getPlotFiles(self, metricId=None):
        """
        Return the metricId, name, metadata, and all plot info (optionally for metricId list).

        Returns a numpy array of the metric information + plot file names,
        with fields metricId, metricName, metricMetadata, plotType, plotFile, thumbFile.
        """
        plotFiles = []
        for mid in self._resolveMetricIds(metricId):
            # Join the metric table and the plot table based on the metricID (the second filter does the join)
            query = (self.session.query(MetricRow, PlotRow)
                     .filter(MetricRow.metricId == mid)
                     .filter(MetricRow.metricId == PlotRow.metricId))
            for m, p in query:
                # Thumbnail name: 'thumb.' + the plot file name without its last extension
                # (and with any other dots removed) + '.png'.
                thumbfile = 'thumb.' + ''.join(p.plotFile.split('.')[:-1]) + '.png'
                plotFiles.append((m.metricId, m.metricName, m.metricMetadata,
                                  p.plotType, p.plotFile, thumbfile))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('plotType', np.str_, self.slen), ('plotFile', np.str_, self.slen),
                          ('thumbFile', np.str_, self.slen)])
        return np.array(plotFiles, dtype)

    def getMetricDataFiles(self, metricId=None):
        """
        Get the metric data filenames for all or a single metric.

        Returns a list.
        """
        dataFiles = []
        for mid in self._resolveMetricIds(metricId):
            rows = self.session.query(MetricRow).filter(MetricRow.metricId == mid).all()
            dataFiles.extend(m.metricDataFile for m in rows)
        return dataFiles

    def getMetricDisplayInfo(self, metricId=None):
        """
        Get the contents of the metrics and displays table, together with the 'basemetricname'
        (optionally, for metricId list).

        The base metric name is the part of the metric name before the first underscore.
        Returns a numpy array of the metric information + display information.
        """
        metricInfo = []
        for mId in self._resolveMetricIds(metricId):
            # Query for all rows in metrics and displays that match this metricId.
            query = (self.session.query(MetricRow, DisplayRow)
                     .filter(MetricRow.metricId == mId)
                     .filter(MetricRow.metricId == DisplayRow.metricId))
            for m, d in query:
                baseMetricName = m.metricName.split('_')[0]
                metricInfo.append((m.metricId, m.metricName, baseMetricName, m.slicerName,
                                   m.sqlConstraint, m.metricMetadata, m.metricDataFile,
                                   d.displayGroup, d.displaySubgroup, d.displayOrder,
                                   d.displayCaption))
        # Convert to numpy array.
        dtype = np.dtype([('metricId', int), ('metricName', np.str_, self.slen),
                          ('baseMetricNames', np.str_, self.slen),
                          ('slicerName', np.str_, self.slen),
                          ('sqlConstraint', np.str_, self.slen),
                          ('metricMetadata', np.str_, self.slen),
                          ('metricDataFile', np.str_, self.slen),
                          ('displayGroup', np.str_, self.slen),
                          ('displaySubgroup', np.str_, self.slen),
                          ('displayOrder', float),
                          ('displayCaption', np.str_, self.slen * 10)])
        return np.array(metricInfo, dtype)