from __future__ import print_function
from simtk.openmm import app
import simtk.openmm as mm
from simtk import unit
from sys import stdout, exit

# This script was generated by openmm-builder. To customize it further,
# save the file to disk and edit it with your favorite editor.
##########################################################################

##########################################################################
##################### User-Defined Options ###############################
##########################################################################
# Thermodynamic state. With npt set to True the run is NPT and `pressure` is
# applied through a Monte Carlo barostat; with npt set to False the run is NVT
# and `pressure` is never used.
temperature = unit.Quantity(288, unit.kelvin)
pressure = unit.Quantity(50.688, unit.bar)
npt = False

# Input files: the force-field XML, and the PDB file that supplies the
# system's geometry and (possibly) its topology.
iforcefield, ipdb = 'sapt_co2.xml', 'co2.pdb'

# Cutoff distances: `two_body_cutoff` is used for pairwise nonbonded forces,
# `three_body_cutoff` for three-body (CustomManyParticleForce) terms.
two_body_cutoff = unit.Quantity(1.1, unit.nanometers)
three_body_cutoff = unit.Quantity(0.7, unit.nanometers)
##########################################################################
##########################################################################


##########################################################################
############# Energy Decomposition Functions #############################
##########################################################################
# OpenMM can report energies and forces force by force when each force is put
# into its own force group. This is sometimes useful for debugging.
def forcegroupify(system):
    """Give every force in *system* its own force group.

    The force at index ``j`` (in ``system.getForce`` order) is assigned to
    force group ``j``. This lets energies and forces be queried one force at
    a time later on.

    NOTE(review): OpenMM limits force-group indices to a small range
    (0-31). A system with more forces than that would fail here; confirm if
    that can happen.

    Args:
        system: the OpenMM ``System`` whose forces are regrouped in place.

    Returns:
        dict mapping each force object to the group index it was given.
    """
    members = [system.getForce(idx) for idx in range(system.getNumForces())]
    for idx, member in enumerate(members):
        member.setForceGroup(idx)
    return dict(zip(members, range(len(members))))

def getEnergyDecomposition(context, forcegroups):
    """Return the potential energy contributed by each force group.

    Args:
        context: the OpenMM ``Context`` to query.
        forcegroups: mapping of force -> force-group index, as produced by
            ``forcegroupify``.

    Returns:
        dict mapping each force to the potential energy of a state that
        includes only that force's group. The ``groups`` argument is the
        bitmask ``2**group``.
    """
    return {
        force: context.getState(getEnergy=True,
                                groups=2 ** group).getPotentialEnergy()
        for force, group in forcegroups.items()
    }

def getForceDecomposition(context, forcegroups):
    """Return the per-particle forces produced by each force group.

    Args:
        context: the OpenMM ``Context`` to query.
        forcegroups: mapping of force -> force-group index, as produced by
            ``forcegroupify``.

    Returns:
        dict mapping each force to ``getForces()`` of a state restricted to
        that force's group (selected with the bitmask ``2**group``).
    """
    decomposition = {}
    for force, group in forcegroups.items():
        state = context.getState(getForces=True, groups=2 ** group)
        decomposition[force] = state.getForces()
    return decomposition
##########################################################################
##########################################################################


##########################################################################
############### Main Script ##############################################
##########################################################################
# Load system topology and positions
app.topology.Topology.loadBondDefinitions('residues.xml')
print('Loaded bond definitions ')
pdb = app.PDBFile(ipdb)
print('Loaded PDB file')
# Input force field
forcefield = app.ForceField(iforcefield)
print('Loaded force field')

# Add the extra particles required by the force field (Drude oscillators
# and/or virtual sites). The topology and positions are read back only AFTER
# addExtraParticles so that they include the newly added particles; reading
# them beforehand would hand createSystem/setPositions the original,
# particle-less model.
model = app.modeller.Modeller(pdb.topology, pdb.positions)
model.addExtraParticles(forcefield)
model_topology = model.getTopology()
model_positions = model.getPositions()

# Create system from force field. The nonbonded method chosen here is
# overridden force by force in the loop below.
system = forcefield.createSystem(model_topology,
                nonbondedMethod=app.NoCutoff,
                nonbondedCutoff=two_body_cutoff,
                # No constraints; pass constraints=app.AllBonds to constrain
                # every bond instead.
                constraints=None,
                polarization='mutual',
                ewaldErrorTolerance=0.0005)
print('Created system')

# Set distance cutoffs, constraints, and other force-specific options for each
# force we might encounter. Forces of any other type keep their defaults.
# NOTE(review): the periodic methods used here need periodic box vectors
# (e.g. a CRYST1 record in the PDB) -- confirm that ipdb provides them.
for force in system.getForces():
    if isinstance(force, mm.CustomHbondForce):
        force.setNonbondedMethod(mm.CustomHbondForce.CutoffPeriodic)
        force.setCutoffDistance(two_body_cutoff)
    elif isinstance(force, mm.CustomNonbondedForce):
        force.setNonbondedMethod(mm.CustomNonbondedForce.CutoffPeriodic)
        force.setCutoffDistance(two_body_cutoff)
        force.setUseLongRangeCorrection(True)
    elif isinstance(force, mm.AmoebaMultipoleForce):
        force.setNonbondedMethod(mm.AmoebaMultipoleForce.PME)
    elif isinstance(force, mm.NonbondedForce):
        force.setNonbondedMethod(mm.NonbondedForce.LJPME)
        force.setCutoffDistance(two_body_cutoff)
    elif isinstance(force, mm.CustomManyParticleForce):
        # Three-body terms use their own (shorter) cutoff.
        force.setNonbondedMethod(mm.CustomManyParticleForce.CutoffPeriodic)
        force.setCutoffDistance(three_body_cutoff)

# The integrator tells us the system temperature, the collision rate, and the
# timestep. The defaults below are good starting values for flexible
# multipolar systems, but it is always important to check the effects of the
# timestep and collision rate
integrator = mm.LangevinIntegrator(
                # system temperature:
                temperature,
                # collision rate
                1.0/unit.picoseconds,
                # timestep:
                0.5*unit.femtoseconds)
integrator.setConstraintTolerance(0.00001)
if npt:
    # Add a barostat to NPT simulations (Monte Carlo move attempted every
    # 25 steps).
    system.addForce(
            mm.MonteCarloBarostat(pressure, temperature, 25))

# Run on the CUDA platform in mixed precision on GPU device index 1.
# To run on the CPU instead, replace the two assignments below with:
#   platform = mm.Platform.getPlatformByName('CPU')
#   properties = {}
platform = mm.Platform.getPlatformByName('CUDA')
properties = {'CudaPrecision': 'mixed', 'DeviceIndex': '1'}

# Set up the simulation. Every force is put in its own force group so that
# the optional per-force energy decomposition below can be printed; uncomment
# the '##' lines to enable it (useful for debugging).
forcegroups = forcegroupify(system)
simulation = app.Simulation(model_topology, system, integrator, platform,
                            properties)
simulation.context.setPositions(model_positions)
##
## ener = 0.0*unit.kilojoule/unit.mole
## energy = getEnergyDecomposition(simulation.context, forcegroups)
## for k, v in energy.items():
##     print(k, v)
##     ener += v
## print(ener)
## print('----')
## position = simulation.context.getState(getPositions=True).getPositions()
## with open('beforemin.pdb', 'w') as handle:
##     app.PDBFile.writeFile(simulation.topology, position, handle)


# Set simulation temperature and data reporters
simulation.context.setVelocitiesToTemperature(temperature)
simulation.reporters.append(app.PDBReporter('output.pdb', 1000))
simulation.reporters.append(app.StateDataReporter('output.out', 100, step=True,
        potentialEnergy=True, kineticEnergy=True, temperature=True,
        density=True, volume=True, speed=True, separator='\t'))

# Run the simulation: 5000 chunks of 2000 steps = 10,000,000 steps, i.e.
# 5 ns with the 0.5 fs timestep above.
print('Running...')
for _ in range(5000):
    simulation.step(2000)
print('Done!')
