#! /usr/bin/env python
from __future__ import division
import os, sys, pprint, argparse, re
from string import strip, Template

# Add the directory holding the UFO conversion modules (imported below as
# ufo2peg) to the module search path.  modulepath is used again at the end
# of the script to locate the Makefile-FR template.
# NOTE(review): "@PKGLIBDIR@" looks like a placeholder that is substituted
# when the script is configured/installed -- confirm against the build setup.
modulepath = os.path.join("@PKGLIBDIR@",'python')
sys.path.append(modulepath)

from ufo2peg import *

# Command-line interface: one positional argument (the UFO model
# directory) and three optional switches.
parser = argparse.ArgumentParser(
    description='Create Herwig++ model files from Feynrules UFO input.')

# Each entry is (option strings, keyword settings).  The entries are
# registered in this order, which is also the order shown in the help text.
_cli_spec = (
    (('ufodir',),
     dict(metavar='UFO_directory',
          help='the UFO model directory')),
    (('-v', '--verbose'),
     dict(action="store_true",
          help="print verbose output")),
    (('-n', '--name'),
     dict(default="FRModel",
          help="set custom nametag for the model")),
    (('--ignore-skipped',),
     dict(action="store_true",
          help="ignore skipped vertices and produce output anyway")),
)
for _option_strings, _settings in _cli_spec:
    parser.add_argument(*_option_strings, **_settings)

# Parse the command line.
args = parser.parse_args()

# Global flag: the vertex loop further down sets it to True when a vertex
# has to be skipped because it cannot be converted.
vertex_skipped = False

def should_print():
    """Tell whether the model files may be written.

    Output is allowed when no vertex was skipped as unconvertible; if one
    was, it is allowed only when --ignore-skipped was given.
    """
    if vertex_skipped:
        return args.ignore_skipped
    return True

# Locate the UFO model package and import it as FR.
#
# The directory is resolved to an absolute path before it is split into
# (parent directory, package name), and the parent is always appended to
# sys.path.  Splitting the raw argument gave an empty parent for a bare
# relative name such as "MyModel"; nothing was then added to sys.path and
# the import only succeeded if the working directory happened to be on the
# path already (for a script, sys.path[0] is the script's own directory,
# not the working directory).
modeldir = args.ufodir.rstrip('/')
modelpath, module = os.path.split(os.path.abspath(modeldir))
sys.path.append(modelpath)

FR = __import__(module)

##################################################
##################################################

# The model name comes from the command line; the shared library is named
# after it.
modelname = args.name
libname = '%s.so' % modelname



# Containers that are filled while looping over the model parameters.
parmdecls, parmgetters, parmconstr, doinit = [], [], [], []

paramstoreplace_, paramstoreplace_expressions_ = [], []

# Table of parameter values, seeded with the external parameters
# (name -> value as given by the model).
parmsubs = {}
for p in FR.all_parameters:
    if p.nature == 'external':
        parmsubs[p.name] = p.value

# evaluate python cmath
def evaluate(x):
    """Evaluate the python expression string x and return the result.

    The expression sees the cmath module and the model's complexconjugate
    function as globals, and the current parameter table parmsubs as
    locals.
    NOTE(review): this calls eval(), so the UFO model is treated as
    trusted input.
    """
    import cmath
    namespace = {
        'cmath': cmath,
        'complexconjugate': FR.function_library.complexconjugate,
    }
    return eval(x, namespace, parmsubs)


## lazily select the internal parameters of the model
internal = ( param for param in FR.all_parameters
             if param.nature == 'internal' )

#paramstoreplaceEW_ = []
#paramstoreplaceEW_expressions_ = []

# Evaluate the internal parameters one after the other and store each
# result in parmsubs, so that the expressions evaluated afterwards can
# refer to it.
for p in internal:
    parmsubs[p.name] = evaluate(p.value)
#    if 'aS' in p.value and p.name != 'aS':
#        paramstoreplace_.append(p.name)
#        paramstoreplace_expressions_.append(p.value)
#    if 'aEWM1' in p.value and p.name != 'aEWM1':
#        paramstoreplaceEW_.append(p.name)
#        paramstoreplaceEW_expressions_.append(p.value)



# more arrays used for substitution in templates
paramsforstream, parmmodelconstr = [], []

# loop over parameters and fill in template stuff according to internal/external and complex/real
# WARNING: Complex external parameter input not tested!
if args.verbose:
    # header for the per-parameter lines printed by the parameter loop
    print('verbose mode on: printing all parameters')
    print('-' * 60)
    paramsstuff = ('name', 'expression', 'default value', 'nature')
    pprint.pprint(paramsstuff)



# Text of the Parameter interface declaration that is generated for each
# real external parameter; the fields are filled in with str.format().
interfacedecl_T = """\
static Parameter<{modelname}, {type}> interface{pname}
  ("{pname}",
   "The interface for parameter {pname}",
   &{modelname}::{pname}_, {value}, 0, 0,
   false, false, Interface::nolimits);
"""

# generated interface declarations, in parameter order
interfaceDecls = []
# LHA label -> value, for the ${...} placeholders of the model file
modelparameters = {}

# Classify every model parameter and collect the text snippets generated
# for it: member declaration, getter, constructor initialiser, stream
# entry, and -- depending on its nature -- an interface declaration with a
# "set" line (external) or a doinit assignment (internal).
for p in FR.all_parameters:
    value = parmsubs[p.name]

    if p.type == 'real':
        # A real parameter must have a vanishing imaginary part of either
        # sign.  Written as "not abs(..) < eps" so that NaN is rejected too,
        # and as an explicit raise so the check survives "python -O".
        if not abs(value.imag) < 1.0e-16:
            raise Exception('Parameter "%s" is declared real but has imaginary part %s.'
                            % (p.name, value.imag))
        value = value.real
        if p.nature == 'external':
            interfaceDecls.append( 
                interfacedecl_T.format(modelname=modelname,
                                       pname=p.name,
                                       value=value,
                                       type=typemap(p.type)) 
            )
            if hasattr(p,'lhablock'):
                # The parameter has an LHA block entry: the model file refers
                # to it through a ${block_code} placeholder, and the value is
                # kept in modelparameters for the later substitution.
                lhalabel = '{lhablock}_{lhacode}'.format( lhablock=p.lhablock, lhacode='_'.join(map(str,p.lhacode)) )
                parmmodelconstr.append('set %s:%s ${%s}' % (modelname, p.name, lhalabel))
                modelparameters[lhalabel] = value
                parmsubs[p.name] = lhalabel
            else:
                parmmodelconstr.append('set %s:%s %s' % (modelname, p.name, value))
                parmsubs[p.name] = value
            parmconstr.append('%s_(%s)' % (p.name, value))
        else :
            parmconstr.append('%s_()' % p.name)
            parmsubs[p.name] = value
    elif p.type == 'complex':
        value = complex(value)
        if p.nature == 'external':
#
# TODO: WE DO NOT HAVE COMPLEX INTERFACES IN THEPEG (yet?)
#
#            interfaceDecls.append( 
#                interfacedecl_T.format(modelname=modelname,
#                                       pname=p.name,
#                                       value='Complex(%s,%s)'%(value.real,value.imag),
#                                       type=typemap(p.type)) 
#            )
#
#            parmmodelconstr.append('set %s:%s (%s,%s)' % (modelname, p.name, value.real, value.imag))
            parmconstr.append('%s_(%s,%s)' % (p.name, value.real, value.imag))
        else :
            parmconstr.append('%s_(%s,%s)' % (p.name, 0.,0.))
        parmsubs[p.name] = value
    else:
        raise Exception('Unknown data type "%s".' % p.type)

    parmdecls.append('  %s %s_;' % (typemap(p.type), p.name))
    parmgetters.append('  %s %s() const { return %s_; }' % (typemap(p.type),p.name, p.name))
    paramsforstream.append('%s_' % p.name)
    if p.nature != 'external':
        # internal parameters are recomputed from their defining expression,
        # translated to C++, in the generated doinit()
        expression, symbols = py2cpp(p.value)
        text = add_brackets(expression, symbols)
        text=text.replace('()()','()')
        #parmgetters.append('  %s %s() const { return %s; }' % (typemap(p.type),p.name, text))
        doinit.append('   %s_ = %s;'  % (p.name, text) )

### special treatment
#    if p.name == 'aS':
#        expression = '0.25 * sqr(strongCoupling(q2)) / Constants::pi'
#    elif p.name == 'aEWM1':
#        expression = '4.0 * Constants::pi / sqr(electroMagneticCoupling(q2))'
#    elif p.name == 'Gf':
#        expression = 'generator()->standardModel()->fermiConstant() * GeV2'
#    elif p.name == 'MZ':
#        expression = 'getParticleData(ThePEG::ParticleID::Z0)->mass() / GeV'


    if args.verbose:
        pprint.pprint((p.name,p.value, value, p.nature))

# Substitutions for the Model.h / Model.cc templates, assembled from the
# snippets collected in the parameter loop.
parmtextsubs = dict(
    parmgetters='\n'.join(parmgetters),
    parmdecls='\n'.join(parmdecls),
    parmconstr=': %s' % ','.join(parmconstr),
    getters='',
    decls='',
    addVertex='',
    doinit='\n'.join(doinit),
    ostream='<< '.join(paramsforstream),
    istream='>> '.join(paramsforstream),
    refs='',
    parmextinter=''.join(interfaceDecls),
    ModelName=modelname,
)

##################################################
##################################################
##################################################


# Templates for the vertex source files and for a single vertex class,
# plus the include line emitted once per distinct vertex base class.
VERTEX = getTemplate('Vertex.cc')
VERTEXCLASS = getTemplate('Vertex_class')
VERTEXHEADER = '#include "ThePEG/Helicity/Vertex/{spindirectory}/{lorentztag}Vertex.h"\n'

def write_vertex_file(subs):
    """Write one generated vertex source file.

    subs is the substitution mapping for the VERTEX template.  Its
    'ModelName' and 'vertexnumber' entries also determine the file name,
    <ModelName>_Vertices_<NNN>.cc with the number zero-padded to three
    digits.
    """
    filename = '%s_Vertices_%03d.cc' % (subs['ModelName'],subs['vertexnumber'])
    contents = VERTEX.substitute(subs)
    writeFile(filename, contents)

if args.verbose:
    # header for the per-vertex lines printed by the vertex loop
    print('verbose mode on: printing all vertices')
    print('-' * 60)
    labels = ('vertex', 'particles', 'Lorentz', 'C_L', 'C_R', 'norm')
    pprint.pprint(labels)

### initial pass to find global sign
def global_sign(detect=False):
    """Return the global sign used in the vertex prefactors.

    By default the sign is fixed to 1.0.  The model scan that used to sit
    unreachable behind an unconditional return is available on request.

    detect -- when True, look for the vertex whose particles have the PDG
              codes -11, 11 and 22 and derive the sign from the imaginary
              part of its single coupling.  If the model has no such
              vertex, 1.0 is returned.

    Raises Exception if that vertex does not have exactly one coupling, or
    if the coupling has a non-zero real part.
    """
    if not detect:
        return 1.0
    for v in FR.all_vertices:
        pids = sorted([ p.pdg_code for p in v.particles ])
        if pids != [-11,11,22]: continue
        coup = v.couplings
        if len(coup) != 1:
            raise Exception('Expected exactly one coupling in vertex %s.' % v.name)
        # list() keeps the indexing valid for dictionary views as well
        val = list(coup.values())[0].value
        val = evaluate(val)
        if val.real != 0:
            raise Exception('Coupling of vertex %s is not purely imaginary.' % v.name)
        return 1.0 if val.imag > 0 else -1.0
    # no matching vertex in the model: keep the default sign
    return 1.0

# Accumulators for the vertex classes and include lines of the source file
# currently being assembled.
vertexclasses = []
vertexheaders = set()

# With ONE_EACH the model's vertex list is used as it is; otherwise it is
# first passed through collapse_vertices().
ONE_EACH = True
all_vertices = FR.all_vertices if ONE_EACH else collapse_vertices(FR.all_vertices)

globalsign = global_sign()
# Main vertex loop: for every vertex build the substitutions for one
# vertex class, append the class text to vertexclasses, and write the
# accumulated classes to a numbered source file (see the WRAP handling at
# the end of the loop body).  Vertices that are not converted get
# herwig_skip_vertex = True; get_vertices() later leaves them out of the
# model file.  Only unsupported Lorentz structures and SkipThisVertex
# also set the global vertex_skipped flag; ghost and goldstone vertices
# are dropped without setting it.
for vertexnumber,vertex in enumerate(all_vertices,1):

    lorentztag = unique_lorentztag(vertex)
    
    vertex.herwig_skip_vertex = False

    # remove vertices involving ghost fields
    if 'U' in lorentztag:
        vertex.herwig_skip_vertex = True
        continue


    # attribute names that are checked on each particle to recognise a
    # goldstone boson
    gsnames = ['goldstone',
           'goldstoneboson',
           'GoldstoneBoson']

    # remove vertices involving a particle for which any of these
    # attributes is set to a true value
    for p in vertex.particles:
        # value of the attribute, or False if the particle does not have it
        def gstest(name):
            try:
                return getattr(p,name)
            except AttributeError:
                return False

        if any(map(gstest, gsnames)):
            vertex.herwig_skip_vertex = True
            break

    if vertex.herwig_skip_vertex: 
        continue


    # factor (as a python expression string) that enters the prefactor of
    # every coupling, per supported Lorentz tag
    lfactors = { 
        'FFV'  : '-complex(0,1)',  # ok
        'VVV'  : 'complex(0,1)',   # changed to fix ttbar
        'VVVV' : 'complex(0,1)',
        'VVS'  : '-complex(0,1)',
        'VSS'  : '-complex(0,1)', # changed to minus to fix dL ->x1 W- d
        'SSS'  : '-complex(0,1)',  # ok
        'VVSS' : '-complex(0,1)',  # ok
        'VVT'  : 'complex(0,2)',
        'VVVT' : '-complex(0,2)',
        'SSSS' : '-complex(0,1)',  # ok
        'FFS'  : '-complex(0,1)',  # ok
        'SST'  : 'complex(0,2)',
        'FFT'  : '-complex(0,8)',
        'FFVT' : '-complex(0,4)',
    }

    # a Lorentz tag without an entry in lfactors is not supported: warn on
    # stderr and skip the vertex
    try:
        lf = lfactors[lorentztag]
    except KeyError:
        msg = 'Warning: Lorentz structure {tag} ( {ps} ) in {name} ' \
              'is not supported.\n'.format(tag=lorentztag, name=vertex.name, 
                                           ps=' '.join(map(str,vertex.particles)))
        sys.stderr.write(msg)
        vertex.herwig_skip_vertex = True
        vertex_skipped = True
        continue

    ### Particle ids
    # comma-separated PDG codes, one string per particle combination
    if ONE_EACH:
        plistarray = [ ','.join([ str(p.pdg_code) for p in vertex.particles ]) ]
    else:
        plistarray = [ ','.join([ str(p.pdg_code) for p in pl ])
                       for pl in vertex.particles_list ]

    # colour structure and colour factors; the helpers signal an
    # unconvertible vertex with SkipThisVertex
    try:
        L,pos = colors(vertex)
        cf = colorfactor(vertex,L,pos)
    except SkipThisVertex:
        vertex.herwig_skip_vertex = True
        vertex_skipped = True
        continue
    
    ### classname
    classname = 'V_%s' % vertex.name

    ### parse couplings
    # each CheckUnique instance is called with the QCD / QED order of every
    # coupling of this vertex
    # NOTE(review): presumably it enforces that all couplings share the same
    # order -- confirm in ufo2peg
    unique_qcd = CheckUnique()
    unique_qed = CheckUnique()
    
    # python expression strings that are summed up below into the left,
    # right and norm couplings
    coup_left  = []
    coup_right = []
    coup_norm = []
    
    if ONE_EACH:
        items = vertex.couplings.iteritems()
    else:
        items = vertex.couplings

    # NOTE: qcd, qed and ordering are read again after this loop and then
    # hold the values of the last coupling that was processed
    try:
        for (color_idx,lorentz_idx),coupling in items:

            qcd, qed = qcd_qed_orders(vertex, coupling)
            unique_qcd( qcd )
            unique_qed( qed )

            # L is reused here for the Lorentz structure of this coupling
            L = vertex.lorentz[lorentz_idx]
            prefactors = '(%s) * (%s) * (%s)' \
                            % (globalsign**(len(lorentztag)-2),lf,cf[color_idx])

            ordering = ''
            if lorentztag in ['FFS','FFV']:
                # fermion vertices: separate left and right couplings
                left,right = parse_lorentz(L.structure)
                if left:
                    coup_left.append('(%s) * (%s) * (%s)' % (prefactors,left,coupling.value))
                if right:
                    coup_right.append('(%s) * (%s) * (%s)' % (prefactors,right,coupling.value))
                if lorentztag == 'FFV':
                    # generated C++: if the first particle passed in differs
                    # from the first particle of the UFO vertex, swap left and
                    # right with a sign flip
                    ordering = ('if(p1->id()!=%s) {Complex ltemp=left(), rtemp=right(); left(-rtemp); right(-ltemp);}' 
                                % vertex.particles[0].pdg_code)
            elif 'T' in lorentztag :
                # vertices with a tensor: handled by tensorCouplings()
                all_coup, left_coup, right_coup, ordering = \
                    tensorCouplings(vertex,coupling,prefactors,L,lorentztag,pos)
                coup_norm  += all_coup
                coup_left  += left_coup
                coup_right += right_coup
            else:
                if lorentztag == 'VSS':
                    # one extra sign for this particular momentum structure
                    if L.structure == 'P(1,3) - P(1,2)':
                        prefactors += ' * (-1)'
                    ordering = 'if(p2->id()!=%s){norm(-norm());}' \
                                   % vertex.particles[1].pdg_code
                elif lorentztag == 'VVVV':
                    if qcd==2:
                        ordering = 'setType(1);\nsetOrder(0,1,2,3);'
                    else:
                        ordering, factor = EWVVVVCouplings(vertex,L)
                        prefactors += ' * (%s)' % factor
                elif lorentztag == 'VVV':
                    if len(pos[8]) != 3:
                        ordering = VVVordering(vertex)

                # a coupling may be a single object or a list of objects
                # whose values are added up
                if type(coupling) is not list:
                    value = coupling.value
                else:
                    value = "("
                    for coup in coupling :
                        value += '+(%s)' % coup.value
                    value +=")"
                coup_norm.append('(%s) * (%s)' % (prefactors,value))


            #print 'Colour  :',vertex.color[color_idx]
            #print 'Lorentz %s:'%L.name, L.spins, L.structure
            #print 'Coupling %s:'%C.name, C.value, '\nQED=%s'%qed, 'QCD=%s'%qcd
            #print '---------------'
    except SkipThisVertex:
        vertex.herwig_skip_vertex = True
        vertex_skipped = True
        continue


    # sum up the collected terms; an empty left/right sum becomes '0', an
    # empty norm becomes '1'
    leftcontent  = ' + '.join(coup_left)  if coup_left  else '0'
    rightcontent = ' + '.join(coup_right) if coup_right else '0'
    normcontent  = ' + '.join(coup_norm)  if coup_norm  else '1'

#    #print 'Left:',leftcontent
#    #print 'Right:',rightcontent
#    #print 'Norm:',normcontent
#    #print '---------------'


#    ### do we need left/right?
    # symbols collects the names returned by py2cpp for all expressions of
    # this vertex; they are defined from the model further down
    symbols = set()
    if 'FF' in lorentztag and lorentztag != "FFT":
#        leftcalc = aStoStrongCoup(py2cpp(leftcontent)[0], paramstoreplace_, paramstoreplace_expressions_)
#        rightcalc = aStoStrongCoup(py2cpp(rightcontent)[0], paramstoreplace_, paramstoreplace_expressions_)
        leftcalc, sym = py2cpp(leftcontent)
        symbols |= sym
        rightcalc, sym = py2cpp(rightcontent)
        symbols |= sym
        left = 'left(' + leftcalc + ');'
        right = 'right(' + rightcalc + ');'
    else:
        left = ''
        right = ''
        leftcalc = ''
        rightcalc = ''
        

#    normcalc = aStoStrongCoup(py2cpp(normcontent)[0], paramstoreplace_, paramstoreplace_expressions_)
    normcalc, sym = py2cpp(normcontent)
    symbols |= sym
    # UFO is GeV by default (?)
    # these tags get an explicit unit factor: multiplied by GeV/UnitRemoval::E
    # for VVS and SSS, divided by it for the tensor vertices
    if lorentztag in ['VVS','SSS']:
        normcalc = '(%s) * GeV / UnitRemoval::E' % normcalc
    elif lorentztag in ['FFT','VVT', 'SST', 'FFVT', 'VVVT' ]:
        normcalc = 'Complex((%s) / GeV * UnitRemoval::E)' % normcalc
    norm = 'norm(' + normcalc + ');'

    # define unknown symbols from the model
    symboldefs = [ def_from_model(FR,s) for s in symbols ]
        
    # one ',tcPDPtr' argument per particle; an argument gets a name (p1, p2,
    # ...) only where the generated ordering code refers to it
    couplingptrs = [',tcPDPtr']*len(lorentztag)
    if lorentztag == 'VSS':
        couplingptrs[1] += ' p2'
    elif lorentztag == 'FFV':
        couplingptrs[0] += ' p1'
    elif lorentztag == 'VVV' and ordering != '':
        couplingptrs[0] += ' p1'
        couplingptrs[1] += ' p2'
        couplingptrs[2] += ' p3'
    elif lorentztag == 'VVVT' and ordering != '':
        couplingptrs[0] += ' p1'
        couplingptrs[1] += ' p2'
        couplingptrs[2] += ' p3'
    elif lorentztag == 'VVVV' and qcd != 2:
        couplingptrs[0] += ' p1'
        couplingptrs[1] += ' p2'
        couplingptrs[2] += ' p3'
        couplingptrs[3] += ' p4'
        
    ### assemble dictionary and fill template
    subs = { 'lorentztag' : lorentztag,                   # ok
             'classname'  : classname,            # ok
             'symbolrefs' : '\n    '.join(symboldefs),
             'left'       : left,                 # doesn't always exist in base
             'right'      : right,                 # doesn't always exist in base 
             'norm'      : norm,                 # needs norm, too

             #################### need survey which different setter methods exist in base classes

             'addToPlist' : '\n'.join([ 'addToList(%s);'%s for s in plistarray]),
             'parameters' : '',
             'setCouplings' : '',
             'qedorder'   : qed,
             'qcdorder' : qcd,
             'couplingptrs' : ''.join(couplingptrs),
             'spindirectory' : spindirectory(lorentztag),
             'ModelName' : modelname,
             'ordering' : ordering,
             }             # ok

    if args.verbose:
        print '-'*60
        pprint.pprint(( classname, plistarray, leftcalc, rightcalc, normcalc ))
        
    vertexclasses.append(VERTEXCLASS.substitute(subs))
    vertexheaders.add(VERTEXHEADER.format(**subs))

    # Write out the accumulated classes whenever the vertex number is a
    # multiple of WRAP, then start a new batch.
    # NOTE(review): this test is only reached for vertices that were not
    # skipped, so if the vertex at a multiple of WRAP is skipped the batch
    # keeps growing until a later multiple is reached.
    WRAP = 25
    if vertexnumber % WRAP == 0:
        write_vertex_file({'classname':classname,
                           'vertexnumber' : vertexnumber//WRAP,
                           'vertexclasses' : '\n'.join(vertexclasses),
                           'vertexheaders' : ''.join(vertexheaders),
                           'ModelName' : modelname})
        vertexclasses = []
        vertexheaders = set()


if not should_print():
    # At least one vertex could not be converted and --ignore-skipped was
    # not given: explain and stop with a non-zero exit status.
    sys.stderr.write(
        "\n"
        "Error: The conversion was unsuccessful, some vertices could not be\n"
        "       generated. If you think the missing vertices are not important \n"
        "       and want to go ahead anyway, use --ignore-skipped. \n"
        "       Herwig++ may not give correct results, though.\n"
        "\n"
    )
    sys.exit(1)
        
# Write out the vertex classes collected since the last complete file, if
# there are any; the file number continues after the last multiple of WRAP.
if vertexclasses:
    leftover = {
        'classname': classname,
        'vertexnumber': vertexnumber//WRAP + 1,
        'vertexclasses': '\n'.join(vertexclasses),
        'vertexheaders': ''.join(vertexheaders),
        'ModelName': modelname,
    }
    write_vertex_file(leftover)

print('=' * 60)

##################################################
##################################################
##################################################

# Two model-file lines per generated vertex: create the vertex object, then
# insert it into the model's ExtraVertices.
vertexline = (
    'create Herwig::{modelname}V_{vname} /Herwig/{modelname}/V_{vname}\n'
    'insert {modelname}:ExtraVertices 0 /Herwig/{modelname}/V_{vname}\n'
)

def get_vertices():
    """Build the vertex section of the model file.

    Returns one string made of the library line, a create/insert pair for
    every vertex that was not skipped, and the insert lines for the two
    generic vertices V_GenericHPP and V_GenericHGG.
    """
    pieces = ['library %s\n' % libname]
    pieces.extend(
        vertexline.format(modelname=modelname, vname=vtx.name)
        for vtx in all_vertices
        if not vtx.herwig_skip_vertex
    )
    generic_lines = (
        'insert {modelname}:ExtraVertices 0 /Herwig/{modelname}/V_GenericHPP\n',
        'insert {modelname}:ExtraVertices 0 /Herwig/{modelname}/V_GenericHGG\n',
    )
    for line in generic_lines:
        pieces.append(line.format(modelname=modelname))
    return ''.join(pieces)


plist, names = thepeg_particles(FR,parmsubs,modelname,modelparameters)

# One commented-out HPConstructor line per particle name, for the example
# input file.
particlelist = []
for p in names:
    particlelist.append(
        "# insert HPConstructor:Outgoing 0 /Herwig/{n}/Particles/{p}".format(n=modelname,p=p)
    )
# Activate the first line by dropping its leading "# ", so that the example
# .in file has something runnable.
particlelist[0] = particlelist[0][2:]
particlelist = '\n'.join(particlelist)

# substitutions for the FR.model template
modelfilesubs = dict(
    plist=plist,
    vlist=get_vertices(),
    setcouplings='\n'.join(parmmodelconstr),
    ModelName=modelname,
)

# Write all output files from their templates, using the substitutions
# assembled above.
if should_print():
    MODEL_HWIN = getTemplate('LHC-FR.in')
    MODEL_H  = getTemplate('Model.h')
    MODEL_CC = getTemplate('Model.cc')
    MODELINFILE = getTemplate('FR.model')

    # example input file
    hwin_subs = { 'ModelName' : modelname,
                  'Particles' : particlelist }
    writeFile( 'LHC-%s.in' % modelname, MODEL_HWIN.substitute(hwin_subs) )

    # The model file is filled in two stages: the first substitution inserts
    # the particle/vertex/coupling text, the second one (below) replaces the
    # ${...} parameter placeholders contained in that text.
    modeltemplate = Template( MODELINFILE.substitute(modelfilesubs) )

    header_text = MODEL_H.substitute(parmtextsubs)
    writeFile( '%s.h' % modelname, header_text )
    source_text = MODEL_CC.substitute(parmtextsubs)
    writeFile( '%s.cc' % modelname, source_text )
    writeFile( modelname +'.template', modeltemplate.template )
    writeFile( modelname +'.model', modeltemplate.substitute( modelparameters ) )

    # Copy Makefile-FR to ./Makefile, replacing the placeholder library name
    # by the one of this model.
    makefile_source = os.path.join(modulepath,'Makefile-FR')
    with open(makefile_source,'r') as orig:
        with open('Makefile','w') as dest:
            makefile_text = orig.read()
            dest.write(makefile_text.replace("FeynrulesModel.so", libname))


print 'finished generating model:\t', modelname
print 'model directory:\t\t', args.ufodir
print 'generated:\t\t\t', len(FR.all_vertices), 'vertices'
print '='*60
print 'library:\t\t\t', libname
print 'input file:\t\t\t', 'LHC-' + modelname +'.in'
print 'model file:\t\t\t', modelname +'.model'
print '='*60
print """\
To complete the installation, compile by typing "make".
An example input file is provided as LHC-FRModel.in, 
you'll need to change the required particles in there.
"""
print 'DONE!'
print '='*60

