#! /usr/bin/env python

"""\
%prog [-o outfile] <yodafile1>[:<scale1>] <yodafile2>[:<scale2>] ...
  e.g. %prog run1.yoda run2.yoda run3.yoda  (unweighted merging of three runs)
    or %prog run1.yoda:2.0 run2.yoda:3.142  (weighted merging of two runs)

Merge analysis objects from multiple YODA files, combining the statistics of
objects whose names are found in multiple files. May be used either to merge
disjoint collections of data objects, or to combine multiple statistically
independent runs of the same data objects into one high-statistics run. Optional
scaling parameters may be given to rescale the weights of the objects on a
per-file basis before merging.

By default the output is written to stdout since we can't guess what would be
a good automatic filename choice! Use the -o option to provide an output filename.


IMPORTANT!
  This script is not meant to handle all run merging situations or data objects:
  there are limitations to what can be inferred from data objects alone. If you
  need to do something more complex than the common cases handled by this script,
  please write your own script / program to load and process the data objects.


SCATTERS (E.G. HISTOGRAM RATIOS) CAN'T BE MERGED CORRECTLY

  Note that 'scatter' data objects, as opposed to histograms, cannot be merged
  in a statistically meaningful way by this tool, since they do not preserve
  sufficient statistical information. The canonical example of this is a ratio
  plot: there are infinitely many combinations of numerator and denominator
  which could give the same ratio, and the result does not indicate which of
  those combinations is right (or the effects of correlations in the division).

  If you need to merge Scatter2D objects properly, you can write your own Python
  script or C++ program using the YODA interface, and apply whatever
  case-specific treatment is appropriate. By default this script only averages
  the y values of matching Scatter2D points, weighting by the per-file scale
  arguments and assuming equal run sizes, and prints a warning when it does so.
  Objects of any other unmergeable type are passed through as the first copy
  encountered, with no actual merging having been done.

NORMALIZED, UNNORMALIZED, OR A MIX?

  An important detail in histogram merging is whether a statistical treatment
  for normalized or unnormalized histograms should be used: in the former case
  the normalization scaling must be undone *before* the histograms are added
  together, and then re-applied afterwards. By default this script attempts to
  work out whether histograms with a common name have been normalized by
  checking whether their normalizations are all the same.

  This is not an infallible approach -- in particular if the normalization is to
  a quantity which might fluctuate between runs (a Monte Carlo cross-section
  estimate, for example) then the inputs might be incorrectly identified as
  unnormalized. The --assume-normalized option, which bypasses the 'common norm'
  heuristic, may be used if you know that all the histograms in the files you are
  attempting to merge are normalized... but it will fail on files which contain
  a mixture of normalized and unnormalized plots (the latter do not have
  ScaledBy attributes).

  In such complicated situations you will again be better off writing your own
  script or program to do the merging.

See the source of this script (e.g. use 'less `which %prog`') for more discussion.
"""

# MORE NOTES
#
# If all the input histograms with a particular path are found to have the same
# normalization, and they have ScaledBy attributes indicating that a histogram
# weight scaling has been applied in producing the input histograms, each
# histogram in that group will first be unscaled by its own factor, then
# merged, and then re-normalized to the target value. Otherwise the weights from
# each histogram copy will be directly added together with no attempt to guess an
# appropriate normalization. The normalization guesses (and they are guesses --
# see below) are made *before* application of the per-file scaling arguments.
#
# IMPORTANT: note from the above that this script can't work out what to do
# re. scaling and normalization of output histograms from the input data files
# alone. It may be possible (although unlikely) that input histograms have the
# same normalization but are meant to be added directly. It may also be the case
# (and much more likely) that histograms which should be normalized to a common
# value will not trigger the appropriate treatment due to e.g. statistical
# fluctuations in each run's calculation of a cross-section used in the
# normalization. And anything more complex than a global scaling (e.g. calculation
# of a ratio or asymmetry) cannot be handled at all with a post-hoc scaling
# treatment. The --assume-normalized command line option will force all histograms
# to be treated as if they are normalized in the input, which can be useful if
# you know that all the output histograms are indeed of this nature. If they are
# not, it will go wrong: you have been warned!
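#
# A minimal sketch of the 'common norm' heuristic and of how it misfires, using
# plain floats in place of each histogram's sumW(). The helper name is
# hypothetical and the function is not called by the merging code below; the
# default tolerance mirrors the 1e-3 used there.
def _sketch_same_norms(sumws, tol=1e-3):
    """True if all norms agree with the first non-zero one to within tol (hypothetical helper)."""
    nonzero = [w for w in sumws if w != 0]
    if not nonzero:
        return False  # no non-empty input to compare against
    ref = float(nonzero[0])
    return all(abs(w / ref - 1) < tol for w in sumws)
# e.g. norms of 1.0 and 1.0005 are treated as 'the same', but two runs normalized
# to fluctuating cross-section estimates of 10.2 and 10.5 are not, so they would
# be added directly unless --assume-normalized is given.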
#
# Please use this script as a template if you need to do something more specific.
#
# NOTE: there are many possible desired behaviours when merging runs, depending on
# the factors above as well as whether the files being merged are of homogeneous
# type, heterogeneous type, or a combination of both. It is tempting, therefore,
# to add a large number of optional command-line parameters to this script, to
# handle these cases. Experience from Rivet 1.x suggests that this is a bad idea:
# if a problem is of programmatic complexity then a command-line interface which
# attempts to solve it in general is doomed to both failure and unusability. Hence
# we will NOT add extra arguments for applying different merging weights or
# strategies based on analysis object path regexes, auto-identifying 'types' of
# run, etc., etc.: if you need to merge data files in such complex ways, please
# use this script as a template around which to write logic that satisfies your
# particular requirements.


import yoda, optparse, sys, math

parser = optparse.OptionParser(usage=__doc__)
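parser.add_option('-o', '--output', default='-', dest='OUTPUT_FILE',
                  help="write the merged output to this file ('-', the default, means stdout)")
parser.add_option('--assume-normalized', action="store_true", default=False, dest='ASSUME_NORMALIZED',
                  help="treat all input histograms as normalized, bypassing the 'common norm' heuristic")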
opts, fileargs = parser.parse_args()

## Put the incoming objects into a dict from each path to a list of histos and scalings
analysisobjects_in = {}
for fa in fileargs:
    filename, scale = fa, 1.0
    if ":" in fa:
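        try:
            filename, scale = fa.rsplit(":", 1)
            scale = float(scale)
        except ValueError:
            ## A non-numeric scale would otherwise leak through as a string and break the merge later
            sys.stderr.write("Error processing arg '%s' with file:scale format\n" % fa)
            sys.exit(1)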
    aos = yoda.read(filename)
    for aopath, ao in aos.items():
        ao.setAnnotation("yodamerge_scale", scale)
        analysisobjects_in.setdefault(aopath, []).append(ao)
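
## A minimal sketch of the file[:scale] argument convention handled in the loop
## above, pulled out as a standalone function so it can be tried in isolation.
## The helper name is hypothetical and nothing in this script calls it. Splitting
## on the *last* colon keeps any earlier colons as part of the filename.
def _sketch_split_scale_arg(arg):
    """Return (filename, scale); raises ValueError if the scale is not a number."""
    filename, scale = arg, 1.0
    if ":" in arg:
        filename, scalestr = arg.rsplit(":", 1)
        scale = float(scalestr)
    return filename, scale
## e.g. "run1.yoda:2.0" gives ("run1.yoda", 2.0), and a bare "run1.yoda" gives
## ("run1.yoda", 1.0).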

analysisobjects_out = {}
for p, aos in analysisobjects_in.items():

    ## Check whether normalizations match (possible for histograms via the sumW/integral attrs)
    ## In the absence of better info, we use the norm as a heuristic to change the merging behaviour
    # TODO: Could we just test for a ScaledBy attribute?
    normto = None
    if len(aos) > 1 and all(hasattr(ao, "sumW") for ao in aos):
        ## Check that there are some non-empty input histograms
        nonzero_sumws = [ao.sumW() for ao in aos if ao.sumW() != 0]
        sumw_ref = nonzero_sumws[0] if nonzero_sumws else None
        ## Use the first non-empty histogram (if there is one) as a reference for norm comparisons
        same_norms = sumw_ref and all(abs(ao.sumW()/sumw_ref - 1) < 1e-3 for ao in aos)
        ## Set the normto flag if possible (either because norms match or user-forced)
        if same_norms or opts.ASSUME_NORMALIZED:
            if not all("ScaledBy" in ao.annotations for ao in aos):
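                ## Warn on stderr so that the message cannot end up in merged output written to stdout
                sys.stderr.write("WARNING: Abandoning normalized merge of path %s because not all inputs have ScaledBy attributes\n" % p)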
            else:
                ## Try to compute a target normalization from the 1/scalefactor-weighted norms of each run
                wtot = sum(1/float(ao.annotation("ScaledBy")) for ao in aos)
                normto = sum(ao.sumW() / float(ao.annotation("ScaledBy")) for ao in aos) / wtot

    ## Check that types match, and just output the first one if they don't
    if not all(type(ao) is type(aos[0]) for ao in aos[1:]):
        sys.stderr.write("WARNING: Several types of analysis object found at path %s: cannot be merged\n" % p)
        analysisobjects_out[p] = aos[0]
        continue

    ## If target should be normalized, undo the normalization scaling now
    if normto:
        for ao in aos:
            ao.scaleW(1/float(ao.annotation("ScaledBy")))

    ## Now that the normalization-identifying heuristic is done, apply user scalings
    ## NB. For histos only: we assume that scatters should be treated like ratios / normalized histos
    for ao in aos:
        if hasattr(ao, "scaleW"):
            scale = float(ao.annotation("yodamerge_scale"))
            ao.scaleW(scale)

    ## Make a copy of the (scaled & unnormalized) first object as the basis for the output
    ao_out = aos[0].clone()
    ao_out.rmAnnotation("yodamerge_scale")


    ## Merge for histograms (including weights, normalization, and user scaling)
    if hasattr(ao_out, "__iadd__"):
        for ao in aos[1:]:
            ao_out += ao
        if normto:
            ao_out.normalize(normto)

    ## Merge for Scatter2D (assuming equal run sizes, and applying user scaling)
    elif type(ao_out) is yoda.Scatter2D:
        msg = "WARNING: Scatter2D %s merge assumes equal run sizes" % p
        if any(float(ao.annotation("yodamerge_scale")) != 1.0 for ao in aos):
            msg += " (+ user scaling)"
        if len(aos) > 1:
            sys.stderr.write(msg + "\n")
        npoints = len(ao_out.points)
        for i in range(npoints):
            y_i = eyp_i = eym_i = scalesum = 0.0
            for ao in aos:
                scale = float(ao.annotation("yodamerge_scale"))
                scalesum += scale
                y_i += scale * ao.points[i].y
                eyp_i += (scale * ao.points[i].yErrs[0])**2
                eym_i += (scale * ao.points[i].yErrs[1])**2
            y_i /= scalesum
            eyp_i = math.sqrt(eyp_i) / scalesum
            eym_i = math.sqrt(eym_i) / scalesum
            ao_out.points[i].y = y_i
            ao_out.points[i].yErrs = (eyp_i, eym_i)

    ## Other data types (just warn, and write out the first object)
    else:
        sys.stderr.write("WARNING: Analysis object %s of type %s cannot be merged\n" % (p, str(type(ao_out))))


    ## Put the output AO into the output dict
    analysisobjects_out[p] = ao_out

## Write output
yoda.writeYODA(analysisobjects_out, opts.OUTPUT_FILE)
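
## A minimal sketch of the per-point arithmetic used in the Scatter2D branch above,
## with plain floats in place of YODA points and a single symmetric error per
## input. The helper name is hypothetical and nothing in this script calls it. The
## value is a scale-weighted mean; the error adds the scaled input errors in
## quadrature and divides by the summed scales, which is only sensible for
## statistically independent runs of equal size.
import math

def _sketch_merge_scatter_point(ys, yerrs, scales):
    """Return (y, yerr) for one merged point (hypothetical helper)."""
    scalesum = float(sum(scales))
    y_out = sum(s * yi for s, yi in zip(scales, ys)) / scalesum
    err_out = math.sqrt(sum((s * ei) ** 2 for s, ei in zip(scales, yerrs))) / scalesum
    return y_out, err_out
## e.g. two equally scaled points with y = 1.0 and 3.0 and errors 0.3 and 0.4
## merge to y = 2.0 with an error of about 0.25.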
