#!/usr/bin/env python
# This may need to be /usr/bin/python2.7 on comanche

import argparse, numpy, subprocess, re, sys, os
import numpy, re, sys, subprocess
#import matplotlib.pyplot as plt

# Parse command-line arguments

# The scanned separation is the N...O distance: NH3 keeps its N atom at the
# origin and the H2O molecule is translated along the z-axis.
# NOTE: --min/--max/--delta only feed the uniform grid, which is currently
# commented out further down in favour of a hard-coded non-uniform grid, so
# these three options have no effect at present.
parser = argparse.ArgumentParser(description="""
Create a set of NWChem input files for the NH3...H2O dimer, run NWChem on them,
and parse the output for the SCF and MP2 interaction energies.
""")
parser.add_argument("-m", "--min", dest="minR", type=float, default=1.,
                    help="minimum N...O separation in A")
parser.add_argument("-M", "--max", dest="maxR", type=float, default=5.,
                    help="maximum N...O separation in A")
parser.add_argument("-d", "--delta", dest="deltaR", type=float, default=0.5,
                    help="step size for N...O separation in A")
parser.add_argument("-o", "--output", dest="output", type=str, default="output.txt",
                    help="output text file")
parser.add_argument("-np", "--nproc", dest="nproc", type=int, default=1,
                    help="Number of processors to be used.")
parser.add_argument("-b", "--basis", dest="basis", type=str, default="aug-cc-pvtz",
                    help="basis set used in calculation")

# Parses sys.argv; argparse exits with a usage message on invalid options.
args = parser.parse_args()

# Environment lookups: both variables must be set, otherwise a KeyError is
# raised here before any calculation starts.
nwchem_dir = os.environ["NWCHEM_TOP"]
username = os.environ["USER"]
nproc = args.nproc


# Scan grid.  The uniform grid built from the command-line options is disabled:
#positions = numpy.arange(args.minR, args.maxR + 0.00001, args.deltaR)
# A hand-picked non-uniform grid, specific to this system, is used instead:
positions = numpy.array([
    1.7, 2.1, 2.3, 2.5, 2.7,
    3.1, 4.0, 5.0, 10.0, 20.0,
])

# Echo the run settings (label, one space, value -- one setting per line).
for banner_label, banner_value in (
        ("Basis set ", args.basis),
        ("positions = ", positions),
        ("Output to : ", args.output),
        ("Number of processors : ", nproc)):
    print("%s %s" % (banner_label, banner_value))

# Results table; opened (and truncated) up front and written after the scan.
output = open(args.output, 'w')


# Conversion factors from atomic units (Hartree) to cm^-1 and to kJ/mol:
au2cm = 219474.631371
au2kJ = 2625.499

# Construct input files

# Template for one NWChem input deck.  It is filled in with str.format() in
# the main loop and has five placeholders:
#   {user}                 - user name, spliced into the scratch directory path
#   {RzO}, {RzH4}, {RzH5}  - z coordinates of the water O, H4 and H5 atoms
#   {basis}                - basis-set name used for every atom label
# The deck defines three named geometries -- the full "NH3+H2O" dimer,
# "NH3+ghost" and "ghost+H2O" -- where the "ghost" partner keeps its
# coordinates but uses Bq-prefixed labels, which the basis block maps onto the
# N/H/O library basis sets.  One "task mp2" is issued per geometry, in that
# order (dimer, NH3 with ghost H2O, H2O with ghost NH3).
# NOTE(review): the "scf; vectors atomic; end" lines between tasks presumably
# make each SCF start from an atomic guess instead of the previous geometry's
# vectors -- confirm against the NWChem SCF documentation.
input_file_template = """
title "NH3..H2O BSSE corrected MP2 interaction energy"

scratch_dir /scratch/HDD_2T/{user}/nwchem

geometry units angstrom "NH3+H2O"
  N       0.00000     0.00000     0.00000
  H1      0.98601     0.00000    -0.30465
  H2     -0.46058     0.84108    -0.38142
  H3     -0.46058    -0.84108    -0.38142
  O       0.00000     0.00000     {RzO}
  H4      0.00000     0.00000     {RzH4}
  H5     -0.96568    -0.00000     {RzH5}
end

geometry units angstrom "NH3+ghost"
  N       0.00000     0.00000     0.00000
  H1      0.98601     0.00000    -0.30465
  H2     -0.46058     0.84108    -0.38142
  H3     -0.46058    -0.84108    -0.38142
  BqO     0.00000     0.00000     {RzO}
  BqH     0.00000     0.00000     {RzH4}
  BqH    -0.96568    -0.00000     {RzH5}
end

geometry units angstrom "ghost+H2O"
  BqN       0.00000     0.00000     0.00000
  BqH       0.98601     0.00000    -0.30465
  BqH      -0.46058     0.84108    -0.38142
  BqH      -0.46058    -0.84108    -0.38142
  O         0.00000     0.00000     {RzO}
  H4        0.00000     0.00000     {RzH4}
  H5       -0.96568    -0.00000     {RzH5}
end

basis
 N   library    {basis}
 H1  library    {basis}
 H2  library    {basis}
 H3  library    {basis}
 O   library    {basis}
 H4  library    {basis}
 H5  library    {basis}
 BqN   library  N  {basis}
 BqH   library  H  {basis}
 BqO   library  O  {basis}
end

set geometry "NH3+H2O"
charge 0
task mp2
unset geometry "NH3+H2O"

scf; vectors atomic; end

set geometry "NH3+ghost"
charge 0
task mp2
unset geometry "NH3+ghost"

scf; vectors atomic; end

set geometry "ghost+H2O"
charge 0
task mp2
unset geometry "ghost+H2O"

"""

# Reference z coordinates of the water atoms O, H4 and H5.  The H2O molecule
# is translated rigidly along the z-axis by adding the scan offset to each.
# Units: Angstrom
zO = 0.0
zH4 = -1.03900
zH5 = 0.27836

# Regular expressions that pick the energy lines out of the NWChem output.
# Raw strings keep "\s" from being treated as a (invalid) string escape, and
# the decimal point is escaped so that only a real "digits.digits" number is
# captured -- an unescaped "." would match any character, letting a bare
# integer or a corrupted token be parsed as an energy.
scf_re = re.compile(r"SCF energy\s+(-?[0-9]+\.[0-9]+)")
mp2_re = re.compile(r"Total MP2 energy\s+(-?[0-9]+\.[0-9]+)")

# Maps each scan position to the list of energies parsed for it.
energies = {}

# Main loop through different position values

for delta_z in positions:
    # One scan point: write the NWChem deck, run it, then harvest the
    # SCF and MP2 energies of the three calculations from its output.
    print("%s %s" % ("Current position = ", delta_z))
    input_file_name = "in-R{0:5.3f}.nw".format(delta_z)
    output_file_name = "out-R{0:5.3f}.out".format(delta_z)

    # Write input file: shift the water atoms along z by the scan offset.
    Rz_O = zO + delta_z
    Rz_H4 = zH4 + delta_z
    Rz_H5 = zH5 + delta_z
    RzO = "{0:10.6f}".format(Rz_O)
    RzH4 = "{0:10.6f}".format(Rz_H4)
    RzH5 = "{0:10.6f}".format(Rz_H5)
    # "with" guarantees the file is closed even if formatting/writing fails.
    with open(input_file_name, "w") as input_file:
        input_file.write(input_file_template.format(
            RzO=RzO, RzH4=RzH4, RzH5=RzH5, user=username, basis=args.basis))

    # Run NWChem, capturing stdout and stderr in the per-position output file.
    with open(output_file_name, "w") as output_file:
        exit_status = subprocess.call(
            ["mpirun.mpich", "-np", str(nproc), "nwchem", input_file_name],
            stdout=output_file, stderr=subprocess.STDOUT)
    if exit_status != 0:
        # Do not silently ignore a failed run; the energy-count check below
        # still decides whether the scan can continue.
        print("Warning: NWChem run for {0} exited with status {1:d}".format(
            input_file_name, exit_status))

    # Read output file back in.  A line is tested for the SCF pattern first
    # and for the MP2 pattern only if that did not match.
    scf_energies = []
    mp2_energies = []
    with open(output_file_name, "r") as output_file:
        for line in output_file:
            scf_match = scf_re.search(line)
            if scf_match:
                scf_energies.append(float(scf_match.group(1)))
                continue
            mp2_match = mp2_re.search(line)
            if mp2_match:
                mp2_energies.append(float(mp2_match.group(1)))

    # Exactly three energies of each kind are expected (one per "task mp2").
    if len(scf_energies) != 3 or len(mp2_energies) != 3:
        print("Apparent error in output file! {:d} SCF energies and {:d} MP2 energies found.".format(
            len(scf_energies), len(mp2_energies)))
        sys.exit(1)

    # Stored order: SCF dimer, monomer A, monomer B, then the same for MP2.
    energies[delta_z] = scf_energies + mp2_energies
    scf_dimer, scf_monomerA, scf_monomerB = scf_energies
    mp2_dimer, mp2_monomerA, mp2_monomerB = mp2_energies
    scf_interaction = scf_dimer - scf_monomerA - scf_monomerB
    mp2_interaction = mp2_dimer - mp2_monomerA - mp2_monomerB
    # Progress report in kJ/mol.
    print("%s %s" % ("Eint[SCF] = ", scf_interaction*au2kJ))
    print("%s %s" % ("Eint[MP2] = ", mp2_interaction*au2kJ))

# prepare plots

# Gather each stored energy (index 0-5: SCF dimer/monomer A/monomer B, then
# the same three for MP2) across all scan positions, one numpy array apiece.
(scf_dimer, scf_monomerA, scf_monomerB,
 mp2_dimer, mp2_monomerA, mp2_monomerB) = [
    numpy.array([energies[point][column] for point in positions])
    for column in range(6)
]

# Interaction energies: dimer minus both monomers, converted to kJ/mol.
scf_interaction = au2kJ*(scf_dimer - scf_monomerA - scf_monomerB)
mp2_interaction = au2kJ*(mp2_dimer - mp2_monomerA - mp2_monomerB)

# Header line: nine right-aligned, 10-character column titles.
column_titles = ("   R    ", "SCF monoA", "SCF monoB", "SCF dimr", "SCF int ",
                 "MP2 monoA", "MP2 monoB", "MP2 dimr", "MP2 int ")
output.write("".join("%10s " % title for title in column_titles) + "\n")

# One row per scan position, every value printed with six decimals.
table_rows = zip(positions,
                 scf_monomerA, scf_monomerB, scf_dimer, scf_interaction,
                 mp2_monomerA, mp2_monomerB, mp2_dimer, mp2_interaction)
for table_row in table_rows:
    output.write("".join("%10.6f " % value for value in table_row) + "\n")

output.close()