# Source code for polymerist.mdtools.openfftools.partition

'''For refining a residue cover of an OpenFF Molecule into a partition, recording partition assignments to atom metadata.
Must be called after a successful call to openff.toolkit.topology.topology.from_pdb()
Code derived from original implementation and conceptualization by Connor Davel (https://github.com/openforcefield/polymer_examples/blob/main/monomer_generation/partition.py) '''

__author__ = 'Connor Davel, Timotej Bernat'
__email__ = 'timotej.bernat@colorado.edu'

from typing import TypeAlias

import json
from itertools import combinations
from collections import defaultdict
from collections import deque

import networkx as nx
import numpy as np

from openff.toolkit.topology import Topology


# CUSTOM TYPEHINTS
BondInfoDict : TypeAlias = dict[tuple[int, int], int] # dictionary mapping bond start-and-end atom indices to the bond indices

# HELPER UTILITY FUNCTIONS
def _create_choice_graph(isomorphism_info : dict[int, tuple[list[int], BondInfoDict]]) -> nx.Graph:
    '''
    Build the graph of candidate substructure matches ("choices") from which a partition is later selected

    Parameters
    ----------
    isomorphism_info : dict[int, tuple[list[int], BondInfoDict]]
        Maps each match number to a pair of (indices of the atoms in the match, bonds leaving the match).
        The bonds leaving a match map (edge atom, cap atom) index pairs to bond orders.
        Only the ordering of this dict matters; its keys are discarded.

    Returns
    -------
    choice_graph : nx.Graph
        Graph whose nodes are the positions (0, 1, 2, ...) of the matches in isomorphism_info.
        Matches with no leaving bonds get no node. Each node carries the attributes:
            selected : bool - initialized to False
            overlaps : list[int] - ascending positions of other matches sharing at least one non-cap atom with this one
            bonds    : list[tuple[int, int, int]] - (edge atom, cap atom, bond order) for each leaving bond
            n_atoms  : int - number of atom indices in the match
        Two nodes are joined by an edge (attribute "order" = bond order) whenever a leaving bond of
        one match is the reverse of a leaving bond of the other.
    '''
    isomorphisms = list(isomorphism_info.values()) # original match numbers are of no use; nodes are identified by position

    # record which matches claim each atom (cap atoms excluded), so overlaps are found by lookup instead of a dense atom matrix
    owned_atoms : list[set[int]] = []
    claimants : dict[int, set[int]] = defaultdict(set)
    for iso_pos, (iso_ids, adjacencies) in enumerate(isomorphisms):
        cap_ids = {cap_atom for (_edge_atom, cap_atom) in adjacencies}
        owned = set(iso_ids) - cap_ids
        owned_atoms.append(owned)
        for atom_idx in owned:
            claimants[atom_idx].add(iso_pos)

    choice_graph = nx.Graph()
    all_bonds : list[tuple[int, int, int, int]] = [] # (edge atom, cap atom, bond order, match position) for every leaving bond
    for iso_pos, (iso_ids, adjacencies) in enumerate(isomorphisms):
        bonds = [(edge_atom, cap_atom, bond_type) for (edge_atom, cap_atom), bond_type in adjacencies.items()]
        if not bonds:
            continue # matches not bonded to anything outside themselves are not candidates
        all_bonds.extend((*bond, iso_pos) for bond in bonds)

        overlapping = set().union(*(claimants[atom_idx] for atom_idx in owned_atoms[iso_pos]))
        overlapping.discard(iso_pos) # a match trivially overlaps itself

        choice_graph.add_node(
            iso_pos,
            selected = False,
            overlaps = sorted(overlapping),
            bonds = bonds,
            n_atoms = len(iso_ids),
        )

    # index leaving bonds so each bond's reverse is found by lookup rather than by comparing every pair of bonds
    bond_positions : dict[tuple[int, int, int], list[int]] = defaultdict(list)
    for pos, (edge_atom, cap_atom, bond_type, _iso_pos) in enumerate(all_bonds):
        bond_positions[(edge_atom, cap_atom, bond_type)].append(pos)

    # edges are added in the same order a pairwise (i < j) scan over all_bonds would add them,
    # which keeps networkx neighbor ordering (relied upon during traversal) unchanged
    for pos1, (edge_a1, cap_a1, bond_type1, iso_pos1) in enumerate(all_bonds):
        for pos2 in bond_positions.get((cap_a1, edge_a1, bond_type1), ()):
            if pos2 > pos1:
                choice_graph.add_edge(
                    iso_pos1,
                    all_bonds[pos2][3],
                    order = bond_type1
                )

    return choice_graph

def _traverse(starting_queue : list[int], graph : nx.Graph) -> list[int]:
    '''
    Breadth-first walk through a choice graph, collecting a group of mutually bonded matches

    Starting from the nodes in starting_queue, every leaving bond of a visited node is followed to the
    neighbor whose "bonds" attribute contains the reversed bond, unless any such neighbor is already
    selected. Each bond (identified by its unordered pair of atoms) is crossed at most once.

    Parameters
    ----------
    starting_queue : list[int]
        Nodes from which the walk begins; this list is not modified
    graph : nx.Graph
        Choice graph whose nodes carry "bonds" (list of (edge atom, cap atom, bond order) tuples)
        and "selected" (bool) attributes

    Returns
    -------
    found_nodes : list[int]
        Visited nodes, each listed once, in order of first visit
    '''
    found_nodes : list[int] = []
    visited : set[int] = set()
    crossed_bonds : set[frozenset[int]] = set() # atom pairs of bonds already crossed, orientation-free
    queue = deque(starting_queue) # copy, so the caller's list is left intact

    while queue:
        node = queue.popleft()
        if node in visited:
            # reached again through a second bond (e.g. matches bonded in a cycle); the graph is not
            # modified during the walk, so re-processing would only produce a duplicate entry
            continue
        visited.add(node)
        found_nodes.append(node)

        for edge_a, cap_a, bond_type in graph.nodes[node]['bonds']:
            bond_atoms = frozenset((edge_a, cap_a))
            if bond_atoms in crossed_bonds:
                continue

            # find the neighbor on the other side of this bond; refuse to cross if any such neighbor is already selected
            target = None
            for neighbor in graph.neighbors(node):
                neighbor_node = graph.nodes[neighbor]
                if (cap_a, edge_a, bond_type) not in neighbor_node['bonds']:
                    continue # neighbor is adjacent through some other bond
                if neighbor_node['selected']:
                    target = None
                    break
                target = neighbor

            if target is not None:
                queue.append(target)
                crossed_bonds.add(bond_atoms)

    return found_nodes

# MAIN PARTITIONING FUNCTION
def partition(offtop : Topology) -> bool:
    '''
    Refine the residue cover recorded in atom metadata into a partition, molecule by molecule

    For each molecule, the substructure matches recorded in each atom's JSON-encoded "match_info" metadata
    (match number -> [residue name, substructure query id]) are gathered, and a choice graph of the matches is built.
    A selection of matches is then greedily grown within each connected component of that graph.
    If the selected matches of the best-covering component span every atom of the molecule, each covered
    atom's "residue_name", "substructure_query_id", and "residue_number" metadata are set from its chosen match.
    Molecules in which no match has a bond to an atom outside of it are skipped and left unmodified.

    Parameters
    ----------
    offtop : Topology
        OpenFF toolkit Topology with complete metadata following a successful call to openff.toolkit.topology.topology.from_pdb()

    Returns
    -------
    success : bool
        True if the partition was successful for every molecule.
        False otherwise. This does not mean that the resulting assignment is invalid. It only means that clean
        partitioning cannot be achieved and some matches necessarily overlap. Try choosing other substructures
        if partitioning is necessary. Processing stops at the first failing molecule; metadata already
        assigned to earlier molecules is kept.
    '''
    for offmol in offtop.molecules:
        # decode each atom's match info once, keyed by match number (JSON object keys are strings, hence int())
        atom_match_info = {
            atom.molecule_atom_index : {
                int(query_num) : entry
                    for query_num, entry in json.loads(atom.metadata["match_info"]).items()
            }
            for atom in offmol.atoms
        }

        # gather, per match number, the atoms in the match and the bonds leaving the match
        target_ids : dict[int, list[int]] = defaultdict(list)
        bond_info : dict[int, BondInfoDict] = defaultdict(dict)
        for atom in offmol.atoms:
            atom_idx = atom.molecule_atom_index
            for query_num in atom_match_info[atom_idx]:
                target_ids[query_num].append(atom_idx) # each atom is visited once, so no duplicates arise
                for bond in atom.bonds:
                    # orient the bond so that it begins at the current atom
                    if bond.atom1_index == atom_idx:
                        begin_atom, end_atom = bond.atom1, bond.atom2
                    else:
                        begin_atom, end_atom = bond.atom2, bond.atom1
                    end_idx = end_atom.molecule_atom_index

                    if query_num in atom_match_info[end_idx]:
                        continue # bond lies within this match; only inter-match bonds are recorded
                    bond_info[query_num].setdefault((begin_atom.molecule_atom_index, end_idx), bond.bond_order)

        if not bond_info:
            continue # no match is bonded to anything outside itself, so there is nothing to partition
        assert len(target_ids) == len(bond_info)

        isomorphism_info_dict = {idx : (target_ids[idx], bond_info[idx]) for idx in target_ids}
        idx_to_match_id = dict(enumerate(target_ids)) # choice-graph nodes are positions within isomorphism_info_dict
        choice_G = _create_choice_graph(isomorphism_info_dict)

        # greedily select matches within each connected component, keeping the component covering the most atoms
        biggest_chain = None
        biggest_chain_length = 0
        for chain in nx.connected_components(choice_G):
            subgraph = choice_G.subgraph(chain) # view; "selected" changes are shared with choice_G
            not_searched_nodes = list(subgraph.nodes)
            while not_searched_nodes:
                unique_group = list(dict.fromkeys(_traverse([not_searched_nodes[0]], subgraph))) # dedupe, keeping order
                new_tally = sum(subgraph.nodes[node]['n_atoms'] for node in unique_group)

                # currently-selected matches which would be displaced; a dict keeps order and counts each only once
                overlapping_nodes = {}
                for node in unique_group:
                    for overlapping_node in subgraph.nodes[node]['overlaps']:
                        if (overlapping_node in subgraph.nodes) and subgraph.nodes[overlapping_node]['selected']:
                            overlapping_nodes[overlapping_node] = None
                old_tally = sum(subgraph.nodes[node]['n_atoms'] for node in overlapping_nodes)

                if new_tally > old_tally: # exchange new for old choices
                    for node in unique_group:
                        subgraph.nodes[node]['selected'] = True
                    for node in overlapping_nodes:
                        subgraph.nodes[node]['selected'] = False

                searched = set(unique_group)
                not_searched_nodes = [node for node in not_searched_nodes if node not in searched]

            tally = sum(
                subgraph.nodes[node]['n_atoms']
                    for node in subgraph.nodes
                        if subgraph.nodes[node]['selected']
            )
            if tally > biggest_chain_length:
                biggest_chain, biggest_chain_length = subgraph, tally

        if biggest_chain is None:
            return False # no matches could be selected at all, so this molecule cannot be covered

        # finally, check that the largest selection covers every atom of the molecule
        largest_mapping_group = [node for node in biggest_chain.nodes if biggest_chain.nodes[node]['selected']]
        covered_ids = set()
        for iso_id in largest_mapping_group:
            covered_ids.update(target_ids[idx_to_match_id[iso_id]])
        if covered_ids != set(atom_match_info): # keys of atom_match_info are the indices of all atoms in the molecule
            return False # if one mol fails, return False

        # assign metadata from each atom's chosen match
        for iso_id in largest_mapping_group:
            res_id = idx_to_match_id[iso_id]
            for m_id in target_ids[res_id]:
                atom = offmol.atom(m_id)
                picked_match = atom_match_info[m_id][res_id]
                atom.metadata["residue_name"] = picked_match[0]
                atom.metadata["substructure_query_id"] = picked_match[1]
                atom.metadata["residue_number"] = res_id

    return True # if all are successful, return True