////////////////////////////////////////////////////////////////////////////////////////
//
//  Copyright 2025 OVITO GmbH, Germany
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify it either under the
//  terms of the GNU General Public License version 3 as published by the Free Software
//  Foundation (the "GPL") or, at your option, under the terms of the MIT License.
//  If you do not alter this notice, a recipient may use your version of this
//  file under either the GPL or the MIT License.
//
//  You should have received a copy of the GPL along with this program in a
//  file LICENSE.GPL.txt.  You should have received a copy of the MIT License along
//  with this program in a file LICENSE.MIT.txt
//
//  This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
//  either express or implied. See the GPL or the MIT License for the specific language
//  governing rights and limitations.
//
////////////////////////////////////////////////////////////////////////////////////////

#include <ovito/particles/Particles.h>
#include <ovito/particles/util/CutoffNeighborFinder.h>
#include <ovito/particles/objects/BondsVis.h>
#include <ovito/particles/objects/ParticleType.h>
#include <ovito/stdobj/simcell/SimulationCell.h>
#include <ovito/core/dataset/DataSet.h>
#include <ovito/core/dataset/pipeline/ModificationNode.h>
#include <ovito/core/utilities/concurrent/ParallelFor.h>
#include <ovito/core/utilities/units/UnitsManager.h>
#include "CreateBondsModifier.h"

#include <boost/range/numeric.hpp>

namespace Ovito {

// Class registration and class metadata (display name, description, modifier category).
IMPLEMENT_CREATABLE_OVITO_CLASS(CreateBondsModifier);
OVITO_CLASSINFO(CreateBondsModifier, "DisplayName", "Create bonds");
OVITO_CLASSINFO(CreateBondsModifier, "Description", "Creates bonds between particles.");
OVITO_CLASSINFO(CreateBondsModifier, "ModifierCategory", "Visualization");
// Parameter fields and sub-object references of the modifier.
// NOTE(review): the definition order is kept as-is in case it matters for field registration/serialization.
DEFINE_PROPERTY_FIELD(CreateBondsModifier, cutoffMode);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, uniformCutoff);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, pairwiseCutoffs);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, minimumCutoff);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, vdwPrefactor);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, onlyIntraMoleculeBonds);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, skipHydrogenHydrogenBonds);
DEFINE_PROPERTY_FIELD(CreateBondsModifier, autoDisableBondDisplay);
DEFINE_REFERENCE_FIELD(CreateBondsModifier, bondType);
DEFINE_REFERENCE_FIELD(CreateBondsModifier, bondsVis);
// Human-readable labels of the parameter fields.
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, cutoffMode, "Cutoff mode");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, uniformCutoff, "Cutoff radius");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, pairwiseCutoffs, "Pair-wise cutoffs");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, minimumCutoff, "Lower cutoff");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, vdwPrefactor, "VdW prefactor");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, onlyIntraMoleculeBonds, "Suppress inter-molecular bonds");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, bondType, "Bond type");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, bondsVis, "Visual element");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, skipHydrogenHydrogenBonds, "Don't generate H-H bonds");
SET_PROPERTY_FIELD_LABEL(CreateBondsModifier, autoDisableBondDisplay, "Auto-disable bond display");
// Units and lower bounds (0) of the numeric parameters.
SET_PROPERTY_FIELD_UNITS_AND_MINIMUM(CreateBondsModifier, uniformCutoff, WorldParameterUnit, 0);
SET_PROPERTY_FIELD_UNITS_AND_MINIMUM(CreateBondsModifier, minimumCutoff, WorldParameterUnit, 0);
SET_PROPERTY_FIELD_UNITS_AND_MINIMUM(CreateBondsModifier, vdwPrefactor, PercentParameterUnit, 0);

/******************************************************************************
* Initializes a newly created modifier object. Unless default initialization
* is suppressed via the DontInitializeObject flag, this creates the bond type
* assigned to generated bonds and the BondsVis element used to render them.
******************************************************************************/
void CreateBondsModifier::initializeObject(ObjectInitializationFlags flags)
{
    // The base class sets up its own state first.
    Modifier::initializeObject(flags);

    // The caller asked for an uninitialized object: do not create the default sub-objects.
    if(flags.testFlag(ObjectInitializationFlag::DontInitializeObject))
        return;

    // Bond type that every bond generated by this modifier will be assigned to.
    setBondType(OORef<BondType>::create(flags));
    bondType()->initializeType(OwnerPropertyRef(&Bonds::OOClass(), Bonds::TypeProperty));

    // Visual element responsible for rendering the generated bonds.
    setBondsVis(OORef<BondsVis>::create(flags));
}

/******************************************************************************
* Is called when a RefTarget referenced by this object generated an event.
* Watches for the bonds visual element being switched back on, and in that
* case stops the modifier from disabling bond display automatically again.
* All events are forwarded to the base class implementation.
******************************************************************************/
bool CreateBondsModifier::referenceEvent(RefTarget* source, const ReferenceEvent& event)
{
    // True if this event reports that the user has (re-)enabled the bonds visual element.
    const bool bondDisplayReEnabled = source == bondsVis()
        && event.type() == ReferenceEvent::TargetEnabledOrDisabled
        && bondsVis()->isEnabled();

    // Respect the user's explicit choice: never auto-disable bond display from now on.
    if(bondDisplayReEnabled)
        setAutoDisableBondDisplay(false);

    return Modifier::referenceEvent(source, event);
}

/******************************************************************************
* Asks the modifier whether it can be applied to the given input data.
* The modifier is applicable whenever the input contains a Particles object.
******************************************************************************/
bool CreateBondsModifier::OOMetaClass::isApplicableTo(const DataCollection& input) const
{
    const bool inputHasParticles = input.containsObject<Particles>();
    return inputHasParticles;
}

/******************************************************************************
* Sets the cutoff radius for a pair of particle types.
******************************************************************************/
void CreateBondsModifier::setPairwiseCutoff(const QVariant& typeA, const QVariant& typeB, FloatType cutoff)
{
    PairwiseCutoffsList newList = pairwiseCutoffs();
    if(cutoff > 0) {
        newList[qMakePair(typeA, typeB)] = cutoff;
        newList[qMakePair(typeB, typeA)] = cutoff;
    }
    else {
        newList.remove(qMakePair(typeA, typeB));
        newList.remove(qMakePair(typeB, typeA));
    }
    setPairwiseCutoffs(std::move(newList));
}

/******************************************************************************
* Returns the pair-wise cutoff radius for a pair of particle types.
* The (A,B) entry is tried first, then (B,A). Returns 0 if neither exists.
******************************************************************************/
FloatType CreateBondsModifier::getPairwiseCutoff(const QVariant& typeA, const QVariant& typeB) const
{
    const PairwiseCutoffsList& cutoffs = pairwiseCutoffs();

    auto match = cutoffs.find(qMakePair(typeA, typeB));
    if(match == cutoffs.end())
        match = cutoffs.find(qMakePair(typeB, typeA));   // Fall back to the reversed ordering.

    return (match != cutoffs.end()) ? match.value() : FloatType(0);
}

/******************************************************************************
* This method is called by the system when the modifier has been inserted
* into a pipeline.
******************************************************************************/
void CreateBondsModifier::initializeModifier(const ModifierInitializationRequest& request)
{
    Modifier::initializeModifier(request);

    int bondTypeId = 1;
    const PipelineFlowState& input = request.modificationNode()->evaluateInput(request).blockForResult();
    if(const Particles* particles = input.getObject<Particles>()) {
        // Adopt the upstream BondsVis object if there already is one.
        // Also choose a unique numeric bond type ID, which does not conflict with any existing bond type.
        if(const Bonds* bonds = particles->bonds()) {
            if(BondsVis* bondsVis = bonds->visElement<BondsVis>()) {
                setBondsVis(bondsVis);
            }
            if(const Property* bondTypeProperty = bonds->getProperty(Bonds::TypeProperty)) {
                bondTypeId = bondTypeProperty->generateUniqueElementTypeId();
            }
        }

        // Initialize the pair-wise cutoffs based on the van der Waals radii of the particle types.
        if(this_task::isInteractive() && pairwiseCutoffs().empty()) {
            if(const Property* typeProperty = particles->getProperty(Particles::TypeProperty)) {
                PairwiseCutoffsList cutoffList;
                for(const ElementType* type1 : typeProperty->elementTypes()) {
                    if(const ParticleType* ptype1 = dynamic_object_cast<ParticleType>(type1)) {
                        if(ptype1->vdwRadius() > 0.0) {
                            QVariant key1 = ptype1->name().isEmpty() ? QVariant::fromValue(ptype1->numericId()) : QVariant::fromValue(ptype1->name());
                            for(const ElementType* type2 : typeProperty->elementTypes()) {
                                if(const ParticleType* ptype2 = dynamic_object_cast<ParticleType>(type2)) {
                                    if(ptype2->vdwRadius() > 0.0 && (ptype1->name() != QStringLiteral("H") || ptype2->name() != QStringLiteral("H"))) {
                                        // Note: Prefactor 0.6 has been adopted from VMD source code.
                                        FloatType cutoff = 0.6 * (ptype1->vdwRadius() + ptype2->vdwRadius());
                                        QVariant key2 = ptype2->name().isEmpty() ? QVariant::fromValue(ptype2->numericId()) : QVariant::fromValue(ptype2->name());
                                        cutoffList[qMakePair(key1, key2)] = cutoff;
                                    }
                                }
                            }
                        }
                    }
                }
                setPairwiseCutoffs(std::move(cutoffList));
            }
        }
    }
    if(bondType() && bondType()->numericId() == 0) {
        bondType()->setNumericId(bondTypeId);
        bondType()->initializeType(OwnerPropertyRef(&Bonds::OOClass(), Bonds::TypeProperty));
    }
}

/******************************************************************************
* Looks up a particle type in the type list based on the name or the numeric ID.
* A QVariant holding an int is interpreted as a numeric type ID. Any other value
* is converted to a string and compared against each type's nameOrNumericId().
* Returns nullptr if no matching type exists.
******************************************************************************/
const ElementType* CreateBondsModifier::lookupParticleType(const Property* typeProperty, const QVariant& typeSpecification)
{
    // Integer specification: direct lookup by numeric type ID.
    if(typeSpecification.typeId() == QMetaType::Int)
        return typeProperty->elementType(typeSpecification.toInt());

    // Otherwise: linear search by textual identifier.
    const QString wantedName = typeSpecification.toString();
    for(const ElementType* candidate : typeProperty->elementTypes()) {
        if(candidate->nameOrNumericId() == wantedName)
            return candidate;
    }
    return nullptr;
}

/******************************************************************************
* Modifies the input data: generates bonds between pairs of particles whose
* separation satisfies the active cutoff criterion (uniform cutoff, pair-wise
* type cutoffs, or scaled van der Waals radii) and adds them to the Bonds
* container of the Particles object.
*
* The cutoff tables are prepared synchronously here. The neighbor search and
* the insertion of bonds run inside the task started by asyncLaunch().
*
* Throws Exception if the cutoff settings do not yield a positive cutoff range.
******************************************************************************/
Future<PipelineFlowState> CreateBondsModifier::evaluateModifier(const ModifierEvaluationRequest& request, PipelineFlowState&& state)
{
    // Get modifier input.
    Particles* particles = state.expectMutableObject<Particles>();
    particles->verifyIntegrity();

    // The neighbor list cutoff (largest bond length that can be generated in the active mode).
    FloatType maxCutoff = uniformCutoff();
    // Per-type VdW radii, indexed by numeric type ID. A zero entry marks a type without a positive radius.
    std::vector<FloatType> typeVdWRadiusMap;
    // Flags indicating which particle type(s) are hydrogens (only filled if H-H bonds are to be skipped).
    std::vector<bool> isHydrogenType;

    // Build table of squared pair-wise cutoff radii, indexed by numeric type IDs. A zero entry means "no cutoff set".
    const Property* typeProperty = nullptr;
    std::vector<std::vector<FloatType>> pairCutoffSquaredTable;
    if(cutoffMode() == PairCutoff) {
        maxCutoff = 0;
        typeProperty = particles->expectProperty(Particles::TypeProperty);
        if(typeProperty) {
            for(auto entry = pairwiseCutoffs().begin(); entry != pairwiseCutoffs().end(); ++entry) {
                FloatType cutoff = entry.value();
                if(cutoff > 0) {
                    const ElementType* ptype1 = lookupParticleType(typeProperty, entry.key().first);
                    const ElementType* ptype2 = lookupParticleType(typeProperty, entry.key().second);
                    // Entries referring to unknown types or to types with negative IDs are ignored.
                    if(ptype1 && ptype2 && ptype1->numericId() >= 0 && ptype2->numericId() >= 0) {
                        int id1 = ptype1->numericId();
                        int id2 = ptype2->numericId();
                        // Grow the (ragged) table on demand; new entries are zero-initialized.
                        if((int)pairCutoffSquaredTable.size() <= std::max(id1, id2)) pairCutoffSquaredTable.resize(std::max(id1, id2) + 1);
                        if((int)pairCutoffSquaredTable[id1].size() <= id2) pairCutoffSquaredTable[id1].resize(id2 + 1, FloatType(0));
                        if((int)pairCutoffSquaredTable[id2].size() <= id1) pairCutoffSquaredTable[id2].resize(id1 + 1, FloatType(0));
                        pairCutoffSquaredTable[id1][id2] = cutoff * cutoff;
                        pairCutoffSquaredTable[id2][id1] = cutoff * cutoff;
                        if(cutoff > maxCutoff) maxCutoff = cutoff;
                    }
                }
            }
            if(maxCutoff <= 0)
                throw Exception(tr("At least one positive bond cutoff must be set for a valid pair of particle types."));
        }
    }
    else if(cutoffMode() == TypeRadiusCutoff) {
        maxCutoff = 0;
        if(vdwPrefactor() <= 0.0)
            throw Exception(tr("Van der Waal radius scaling factor must be positive."));
        typeProperty = particles->expectProperty(Particles::TypeProperty);
        if(typeProperty) {
            for(const ElementType* type : typeProperty->elementTypes()) {
                if(const ParticleType* ptype = dynamic_object_cast<ParticleType>(type)) {
                    if(ptype->vdwRadius() > 0.0 && ptype->numericId() >= 0) {
                        const size_t typeId = (size_t)type->numericId();
                        if(ptype->vdwRadius() > maxCutoff)
                            maxCutoff = ptype->vdwRadius();
                        if(typeId >= typeVdWRadiusMap.size())
                            typeVdWRadiusMap.resize(typeId + 1, 0.0);
                        typeVdWRadiusMap[typeId] = ptype->vdwRadius();
                        if(skipHydrogenHydrogenBonds()) {
                            if(typeId >= isHydrogenType.size())
                                isHydrogenType.resize(typeId + 1, false);
                            isHydrogenType[typeId] = (ptype->name() == QStringLiteral("H"));
                        }
                    }
                }
            }
            // The largest possible radius-based cutoff is prefactor * (r_max + r_max).
            maxCutoff *= vdwPrefactor() * 2.0;
            if(maxCutoff == 0.0)
                throw Exception(tr("The van der Waals (VdW) radii of all particle types are undefined or zero. Creating bonds based on the VdW radius requires at least one particle type with a positive radius value."));
        }
        OVITO_ASSERT(!typeVdWRadiusMap.empty());
    }
    if(maxCutoff <= 0.0)
        throw Exception(tr("Maximum bond cutoff range is zero. A positive value is required."));

    // Get molecule IDs.
    const Property* moleculeProperty = onlyIntraMoleculeBonds() ? particles->getProperty(Particles::MoleculeProperty) : nullptr;

    // Create the bonds object that will store the generated bonds.
    if(!particles->bonds()) {
        DataOORef<Bonds> bondsObj = DataOORef<Bonds>::create(ObjectInitializationFlag::DontCreateVisElement);
        bondsObj->setCreatedByNode(request.modificationNode());
        bondsObj->setVisElement(bondsVis());
        particles->setBonds(std::move(bondsObj));
    }

    // Perform the main work in a separate thread.
    return asyncLaunch([
            state = std::move(state),
            particles,
            maxCutoff,
            minCutoff = minimumCutoff(),
            moleculeIDs = moleculeProperty,
            particleTypes = typeProperty,
            bondType = DataOORef<BondType>::makeDeepCopy(bondType()), // Note: Passing a deep copy of the original bond type to the data pipeline.
            pairCutoffsSquared = std::move(pairCutoffSquaredTable),
            typeVdWRadiusMap = std::move(typeVdWRadiusMap),
            vdwPrefactor = vdwPrefactor(),
            isHydrogenType = std::move(isHydrogenType),
            autoDisableBondDisplay = autoDisableBondDisplay(),
            createdByNode = request.modificationNodeWeak()]() mutable
    {
        TaskProgress progress(this_task::ui());
        progress.setText(tr("Generating bonds"));

        OVITO_ASSERT(state.data()->dataReferenceCount() == 1);
        OVITO_ASSERT(state.data()->isSafeToModify());
        OVITO_ASSERT(particles->isSafeToModify());
        Bonds* bonds = particles->makeBondsMutable();
        bonds->verifyIntegrity();

        // Prepare the neighbor finder.
        CutoffNeighborFinder neighborFinder(maxCutoff, particles->expectProperty(Particles::PositionProperty), state.getObject<SimulationCell>(), {});

        // The lower bond length cutoff squared.
        FloatType minCutoffSquared = minCutoff * minCutoff;

        BufferReadAccess<IdentifierIntType> moleculeIDsArray(moleculeIDs);
        BufferReadAccess<int32_t> particleTypesArray(particleTypes);

        // Generate bonds.
        size_t particleCount = particles->elementCount();
        // Multi-threaded loop over all particles, each thread producing a partial bonds list.
        auto partialBondsLists = parallelForCollect<std::vector<Bond>>(particleCount, 4096, progress, [&](size_t particleIndex, std::vector<Bond>& bondList) {

            // Get the type of the central particle (only meaningful if per-particle types are used).
            int type1 = 0;
            bool isHydrogenType1 = false;
            if(particleTypesArray) {
                type1 = particleTypesArray[particleIndex];
                if(type1 < 0) return;
                if(type1 < (int)isHydrogenType.size())
                    isHydrogenType1 = isHydrogenType[type1];
            }

            // Iterate over the particle's neighbors within the cutoff range.
            for(CutoffNeighborFinder::Query neighborQuery(neighborFinder, particleIndex); !neighborQuery.atEnd(); neighborQuery.next()) {
                if(neighborQuery.distanceSquared() < minCutoffSquared)
                    continue;
                if(moleculeIDsArray && moleculeIDsArray[particleIndex] != moleculeIDsArray[neighborQuery.current()])
                    continue;

                if(particleTypesArray) {
                    int type2 = particleTypesArray[neighborQuery.current()];
                    if(type2 < 0) continue;
                    if(type1 < (int)typeVdWRadiusMap.size() && type2 < (int)typeVdWRadiusMap.size()) {
                        // Types without a positive VdW radius never take part in bonding. Without this check,
                        // such a type would bond (using only its partner's radius) whenever its ID happened to be
                        // smaller than the largest ID of a type with a radius, but not otherwise.
                        if(typeVdWRadiusMap[type1] <= 0 || typeVdWRadiusMap[type2] <= 0)
                            continue;
                        // Do not generate H-H bonds (if requested).
                        if(isHydrogenType1 && type2 < (int)isHydrogenType.size() && isHydrogenType[type2])
                            continue;
                        FloatType cutoff = vdwPrefactor * (typeVdWRadiusMap[type1] + typeVdWRadiusMap[type2]);
                        if(neighborQuery.distanceSquared() > cutoff*cutoff)
                            continue;
                    }
                    else if(type1 < (int)pairCutoffsSquared.size() && type2 < (int)pairCutoffsSquared[type1].size()) {
                        // A zero entry means no cutoff was defined for this type pair. Reject it explicitly so that
                        // coincident particles (distance zero) of such a pair do not get bonded.
                        FloatType cutoffSquared = pairCutoffsSquared[type1][type2];
                        if(cutoffSquared <= 0 || neighborQuery.distanceSquared() > cutoffSquared)
                            continue;
                    }
                    else continue;
                }

                Bond bond = { particleIndex, neighborQuery.current(), neighborQuery.unwrappedPbcShift() };

                // Skip every other bond to create only one bond per particle pair.
                if(!bond.isOdd())
                    bondList.push_back(bond);
            }
        });

        // Flatten the bonds list into a single std::vector.
        size_t totalBondCount = boost::accumulate(partialBondsLists, (size_t)0, [](size_t n, const std::vector<Bond>& bonds) { return n + bonds.size(); });
        std::vector<Bond> totalBondsList;
        if(totalBondCount != 0) {
            totalBondsList = std::move(partialBondsLists.front());
            totalBondsList.reserve(totalBondCount);
            std::for_each(std::next(partialBondsLists.begin()), partialBondsLists.end(), [&](const std::vector<Bond>& bonds) { totalBondsList.insert(totalBondsList.end(), bonds.begin(), bonds.end()); });
            this_task::throwIfCanceled();
        }

        // Insert bonds into Bonds container.
        size_t numGeneratedBonds = bonds->addBonds(totalBondsList, nullptr, particles, {}, std::move(bondType));

        // Output the number of newly added bonds to the pipeline.
        state.addAttribute(QStringLiteral("CreateBonds.num_bonds"), QVariant::fromValue(numGeneratedBonds), createdByNode);

        // If the total number of bonds is unusually high, we better turn off bonds display to prevent the program from freezing.
        if(bonds->elementCount() > 2000000 && autoDisableBondDisplay && this_task::isInteractive()) {
            if(BondsVis* vis = bonds->visElement<BondsVis>()) {
                // Modifying the vis element must be done in the main thread.
                launchDetached(ObjectExecutor(vis), [vis]() {
                    this_task::ui()->performTransaction(tr("Disable bonds display"), [&]() {
                        vis->setEnabled(false);
                    });
                });
            }
            state.setStatus(PipelineStatus(PipelineStatus::Warning, tr("Created %1 bonds, which is a lot. The display of bonds has been turned off as a precaution. You can manually turn it on again if needed.").arg(numGeneratedBonds)));
        }
        else {
            state.setStatus(PipelineStatus(PipelineStatus::Success, tr("Created %1 bonds.").arg(numGeneratedBonds)));
        }

        // 'state' is a lambda capture, not a local, so it must be moved out explicitly.
        return std::move(state);
    });
}

}   // End of namespace
