#! /usr/bin/env python

"""\
%prog [-o outfile] <datafile1> <datafile2>

Compare analysis objects between two YODA-readable data files.
"""

import yoda, sys, optparse

## Command-line handling: exactly two positional data file names are required.
parser = optparse.OptionParser(usage=__doc__)
# NOTE(review): OUTPUT_FILE is parsed here but is not read anywhere in the
# visible script -- confirm whether output redirection is still to be implemented.
parser.add_option('-o', '--output', default='-', dest='OUTPUT_FILE')
parser.add_option('-t', '--tol', type=float, default=1e-5, dest='TOL',
                  help="relative tolerance used when comparing numbers [default: %default]")
opts, filenames = parser.parse_args()

if len(filenames) != 2:
    print("ERROR! Please supply *two* YODA files for comparison")
    sys.exit(1)

## Get data objects
## (the results are used below as path -> object mappings: len(), keys() and get())
aodict1 = yoda.read(filenames[0])
aodict2 = yoda.read(filenames[1])

## Check number of data objects in each file.
## These are warnings only: the per-path comparison further down still runs.
if len(aodict1) != len(aodict2):
    print("Different numbers of data objects in %s and %s" % tuple(filenames[:2]))
## Compare the paths as sets: the order in which keys() lists them is not
## significant and may differ between two files holding identical paths.
elif set(aodict1.keys()) != set(aodict2.keys()):
    print("Different data object paths in %s and %s" % tuple(filenames[:2]))

## A slightly tolerant numerical comparison function
def eq(a, b, tol=None):
    """Return True if a and b are equal within a relative tolerance.

    Numbers are compared through the symmetric relative difference
    |a - b| / (|a| + |b|), and two NaN values are treated as equal.
    Iterables (other than strings) are compared element by element and must
    have the same length.

    Args:
        a, b: numbers, or (possibly nested) iterables of numbers.
        tol: relative tolerance. If None, the command-line value opts.TOL
            is used.

    Returns:
        True if the values agree within the tolerance, False otherwise.

    Raises:
        TypeError or ValueError: if a non-iterable operand cannot be
            converted to float.
    """
    from math import isnan
    if a == b:
        return True
    ## Type-check: be a bit careful re. int vs. float
    if type(a) is not type(b) and not all(type(x) in (int, float) for x in (a, b)):
        return False
    ## Recursively call on pairs of components if a and b are iterables.
    ## Strings are treated as scalars so they can never recurse character by character.
    if hasattr(a, "__iter__") and not isinstance(a, (str, bytes)):
        avals, bvals = tuple(a), tuple(b)
        ## zip() would silently drop unmatched trailing elements
        if len(avals) != len(bvals):
            return False
        return all(eq(x, y, tol) for x, y in zip(avals, bvals))
    fa, fb = float(a), float(b)
    ## NaN never compares equal to anything, so handle it explicitly
    if isnan(fa) or isnan(fb):
        return isnan(fa) and isnan(fb)
    if tol is None:
        tol = opts.TOL
    ## Finally apply a tolerant numerical comparison on numeric types.
    ## The denominator cannot be zero here: it only vanishes for a == b == 0,
    ## which the equality shortcut above has already accepted.
    # TODO: Be careful with values on either side of zero. Add abs tolerance "epsilon" term?
    return abs(fa - fb) / (abs(fa) + abs(fb)) < tol

def s2str(s):
    """Format a point as the string "(<x part>, <y part>)".

    s must expose x, y, xErrs and yErrs attributes, where each error
    attribute unpacks into two values and supports indexing.

    Each coordinate is written with two significant digits. When the two
    error components compare equal according to eq(), the short form
    "value +- err" is used; otherwise both components are written out.
    """
    symmetric_fmt = "{x:.2g} +- {ex[0]:.2g}"
    asymmetric_fmt = "{x:.2g} + {ex[0]:.2g} - {ex[1]:.2g}"

    def render(value, errs):
        ## Pick the template depending on whether the error pair is symmetric
        if eq(*errs):
            template = symmetric_fmt
        else:
            template = asymmetric_fmt
        return template.format(x=value, ex=errs)

    x_part = render(s.x, s.xErrs)
    y_part = render(s.y, s.yErrs)
    return "({}, {})".format(x_part, y_part)

## Compare each object pair, visiting every path found in either file in sorted order.
## Each problem is reported on stdout and the loop then moves on to the next path.
for path in sorted(set(aodict1.keys()) | set(aodict2.keys())):
    ## Get the object in file #1
    ao1 = aodict1.get(path, None)
    if ao1 is None:
        print("Data object '%s' not found in %s" % (path, filenames[0]))
        continue
    ## Get the object in file #2
    ao2 = aodict2.get(path, None)
    if ao2 is None:
        print("Data object '%s' not found in %s" % (path, filenames[1]))
        continue
    ## Compare the file #1 vs. #2 object types
    if type(ao1) is not type(ao2):
        print("Data objects with path '%s' have different types (%s and %s) in %s and %s" %
              (path, str(type(ao1)), str(type(ao2)), filenames[0], filenames[1]))
        continue
    ## Convert both objects to scatters so they can be compared point by point
    # TODO: Fix the mapping so the exception can be caught in Python
    try:
        s1 = ao1.mkScatter() if type(ao1) is not yoda.Scatter2D else ao1
        s2 = ao2.mkScatter() if type(ao2) is not yoda.Scatter2D else ao2
    except Exception as e:
        print("WARNING! Could not create a '%s' scatter for comparison (%s)" % (path, type(e).__name__))
        ## Skip this path: without this, s1/s2 would be undefined or would
        ## still hold the scatters of the previously processed path.
        continue
    ## Compare the numbers of points/bins
    if s1.numPoints != s2.numPoints:
        print("Warning: data objects with path '%s' have different numbers of points (%d and %d) in %s and %s" %
              (path, s1.numPoints, s2.numPoints, filenames[0], filenames[1]))
        continue
    ## Compare the numeric values of each point
    premsg = "\nData points differ for data objects with path '%s' in %s and %s:\n" % (path, filenames[0], filenames[1])
    msgs = []
    for i, (p1, p2) in enumerate(zip(s1.points, s2.points)):
        same = (eq(p1.x, p2.x) and eq(p1.xErrs, p2.xErrs) and
                eq(p1.y, p2.y) and eq(p1.yErrs, p2.yErrs))
        if not same:
            msgs.append("  Point #%d: %s vs. %s" % (i, s2str(p1), s2str(p2)))
    ## Only print the header when at least one point differs
    if msgs:
        print(premsg + "\n".join(msgs))
