Source code for npfc.draw

"""
Module draw
==============
This module contains functions and classes for drawing images of molecules
with highlighted fragments.
Special care was given to blending colors for overlapping fragments.
"""

# standard
import logging
from pathlib import Path
from math import sqrt
from copy import deepcopy
import math
# data handling
import json
import base64
import numpy as np
from collections import OrderedDict
from itertools import chain
from collections import Counter
# chemoinformatics
from rdkit.Chem import AllChem
from rdkit.Chem import rdChemReactions
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem import Mol
from rdkit.Chem import Atom
from rdkit.Chem import Bond
from rdkit.Chem import Draw
# 2D depiction of molecules
from rdkit.Chem import rdDepictor
from rdkit.Chem import rdCoordGen
# graph
import networkx as nx
# docs
import matplotlib
import matplotlib.pyplot as plt  # required for creating a canvas for displaying graphs
from matplotlib.figure import Figure
from networkx.classes.graph import Graph
from networkx.drawing.nx_agraph import to_agraph
import seaborn as sns
from PIL import Image
from typing import Union
from typing import Set
from typing import List
from typing import Tuple
from typing import Dict
# dev
from npfc import utils
from npfc import fragment_combination_graph
# tmp
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem, rdCoordGen
from scipy.spatial import KDTree
from IPython.display import Image
from npfc import notebook



# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLOBALS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Plain dicts are ordered since Python 3.7, so red would always be #1, green #2, etc.
# An ordering bug was nevertheless encountered, hence the explicit OrderedDict.
DEFAULT_PALETTE = OrderedDict([
    ('red', (1.0, 0.6, 0.6)),
    ('green', (0.2, 1.0, 0.2)),
    ('blue', (0.4, 0.6, 1.0)),
    ('orange', (0.9569, 0.6667, 0.2588)),
    ('purple', (0.8392, 0.6275, 0.7686)),
    ('yellow', (1.0, 0.9294, 0.0)),
    ('teal', (0.5725, 0.9608, 0.9882)),
    ('gray', (0.7294, 0.7294, 0.7294)),
])

# # Another default palette to use in reports
# DEFAULT_PALETTE_255 = OrderedDict()
# for k, v in DEFAULT_PALETTE.items():
#     v = tuple([int(x * 255) for x in v])
#     DEFAULT_PALETTE_255[k] = v


# # nice to have the values as hexadecimals too I guess
# def rgb255_to_hex(rgb: tuple):
#     r, g, b = tuple(rgb)
#     return '#{:02x}{:02x}{:02x}'.format(r, g, b)


# DEFAULT_PALETTE_HEX = OrderedDict()
# for k, v in DEFAULT_PALETTE_255.items():
#     v_hex = rgb255_to_hex(v)
#     DEFAULT_PALETTE_HEX[k] = v_hex.upper()

# #
# #
#
# DEFAULT_PALETTE = {'red': (1.0, 0.6, 0.6),
#                    'green': (0.2, 1.0, 0.2),
#                    'blue': (0.4, 0.6, 1.0),
#                    'orange': (0.9569, 0.6667, 0.2588),
#                    'purple': (0.8392, 0.6275, 0.7686),
#                    'yellow': (1.0, 0.9294, 0.0),
#                    'teal': (0.5725, 0.9608, 0.9882),
#                    'gray': (0.7294, 0.7294, 0.7294),
#                    }

# Hexadecimal counterpart of DEFAULT_PALETTE: same keys, each RGB tuple converted
# to a hex color string with matplotlib.colors.to_hex.
DEFAULT_PALETTE_HEX = {k: matplotlib.colors.to_hex(v) for k, v in DEFAULT_PALETTE.items()}

# matplotlib.colors.to_rgb('#FF0000')
# matplotlib.colors.to_hex(((1.0, 0.0, 0.0)))

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #




def get_d_aidxs_for_rings(mol: Mol, fuse_rings: bool = False) -> dict:
    """Build the dictionary of atom indices (d_aidxs) used to highlight rings.

    This is a helper for instantiating a ColorMap object that highlights the
    rings of a molecule (SSSR or fused ring systems) instead of fragments.

    :param mol: the input molecule
    :param fuse_rings: if False: highlight SSSR, if True: highlight fused ring systems
    :return: a dictionary of syntax: {"R#0": [(0, 1, 2, 3, 4)], "R#1": [(5, 6, 7, 8, 9)]}.
        Every value is a one-element list, for compatibility with fragment highlighting.
    """
    rings = mol.GetRingInfo().AtomRings()
    if fuse_rings:
        rings = utils.fuse_rings(rings)

    d_aidxs = {}
    for ring_idx, ring_atoms in enumerate(rings):
        d_aidxs[f"R#{ring_idx}"] = [ring_atoms]
    return d_aidxs
def mol(mol: Mol,
        colormap: 'ColorMap' = None,
        output_file: str = None,
        img_size: Tuple[int] = (400, 400),
        atom_labels: Union[str, dict] = None,
        # force_depict: bool = False,
        svg: bool = True,
        legend: str = '') -> Image:
    """
    Draw an Image of a molecule with highlighted atoms and bonds according to a colormap.
    If no Colormap object is provided, no highlighting is done.

    This code is based on the 2020.03 RDKit release and the picture below is not yet updated.

    .. image:: _images/draw_highlight.svg
        :align: center

    :param mol: the molecule to highlight. When atom_labels is a dictionary, the
        'atomNote' property of its atoms is set in place.
    :param colormap: the colormap to use for highlighting the molecule
    :param output_file: if specified, the image is saved (format is deduced from
        the extension, SVG or PNG, and has to match the svg argument)
    :param img_size: the size (width, height) of the resulting Image
    :param atom_labels: display atom labels. Parameter can either be the value
        'atom_indices' or a dictionary with atom_index: label (i.e. fcp).
    :param svg: use SVG format instead of PNG
    :param legend: a legend passed to the RDKit drawer
    :return: the drawing text returned by the RDKit drawer (SVG text or PNG data)
    :raises ValueError: if the extension of output_file is SVG while svg is False,
        or if the extension is neither SVG nor PNG
    """
    # if no colormap is provided, do not highlight any atoms
    if colormap is None:
        colormap = ColorMap(mol, {})

    # draw
    if svg:
        d2d = rdMolDraw2D.MolDraw2DSVG(img_size[0], img_size[1])
    else:
        d2d = rdMolDraw2D.MolDraw2DCairo(img_size[0], img_size[1])

    if atom_labels is not None:
        if atom_labels == 'atom_indices':
            d2d.drawOptions().addAtomIndices = True
        else:
            for at in mol.GetAtoms():
                at.SetProp('atomNote', atom_labels[at.GetIdx()])

    # general settings
    d2d.drawOptions().legendFontSize = 24
    d2d.drawOptions().padding = 0.1

    # generate image
    d2d.DrawMoleculeWithHighlights(mol, legend, colormap.atoms, colormap.bonds, {}, {})
    d2d.FinishDrawing()
    img = d2d.GetDrawingText()

    # export image
    if output_file is not None:
        output_ext = str(output_file).split('.')[-1].upper()
        if output_ext == 'SVG':
            if not svg:
                raise ValueError("Error! output file extension is SVG but image format is PNG!")
            with open(output_file, 'w') as fh:
                fh.write(img)
        elif output_ext == 'PNG':
            if svg:
                # writing SVG text into a .png file would produce a corrupt image;
                # historically nothing was written in this case, so only warn
                logging.warning("Output file extension is PNG but image format is SVG, so '%s' was not written!", output_file)
            else:
                # the Cairo drawer returns the PNG as binary data
                with open(output_file, 'wb') as fh:
                    fh.write(img)
        else:
            raise ValueError(f"Error! Unsupported extension '{output_ext}'!")

    return img
def mols(mols: List[Mol],
         colormaps: List['ColorMap'] = None,
         output_file: str = None,
         sub_img_size: Tuple[int] = (300, 300),
         max_mols_per_row: int = 5,
         debug: bool = False,
         svg: bool = True,
         legends: List[str] = None,
         ):
    """
    Draw an Image of a list of molecules with highlighted atoms and bonds according
    to a list of colormaps.

    .. warning:: This function is based on the old RDKit drawing code (< 2020.03) and
        thus it is necessary to blend colors (i.e. colormap.blend()) first to
        represent common atoms/bonds between fragments.

    :param mols: the molecules to highlight
    :param colormaps: the colormaps to use for highlighting the molecules (one per
        molecule). If None or empty, no highlighting is done.
    :param output_file: if specified, the image is saved (format is deduced from
        the extension, SVG or PNG, and has to match the svg argument)
    :param sub_img_size: the size of the image of every molecule composing the grid
    :param max_mols_per_row: the maximum number of molecules displayed per row
    :param debug: display atom indices on the structure (drawn on copies of the molecules)
    :param svg: use SVG format instead of PNG
    :param legends: legends passed to RDKit's MolsToGridImage
    :return: an Image of the highlighted molecules
    :raises ValueError: if the extension of output_file is SVG while svg is False,
        or if the extension is neither SVG nor PNG
    """
    if debug:
        # work on copies so the atom map numbers do not leak into the input molecules
        mols = [Mol(m) for m in mols]
        for m in mols:
            for atom in m.GetAtoms():
                atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))

    if not colormaps:
        # no highlighting: provide one empty entry per molecule so that the
        # highlight lists always have the same length as the list of molecules
        atom_lists = [[] for _ in mols]
        colormaps_a = [{} for _ in mols]
        colormaps_b = [{} for _ in mols]
    else:
        atom_lists = [[int(x) for x in colormap.atoms] for colormap in colormaps]
        colormaps_a = [colormap.atoms for colormap in colormaps]
        colormaps_b = [colormap.bonds for colormap in colormaps]

    img = Draw.MolsToGridImage(mols,
                               molsPerRow=max_mols_per_row,
                               subImgSize=sub_img_size,
                               highlightAtomLists=atom_lists,
                               highlightAtomColors=colormaps_a,
                               highlightBondColors=colormaps_b,
                               useSVG=svg,
                               legends=legends,
                               )

    # export img
    if output_file is not None:
        output_ext = str(output_file).split('.')[-1].upper()
        if output_ext == 'SVG':
            if not svg:
                raise ValueError("Error! output file extension is SVG but image format is PNG!")
            # the grid is either an object exposing the SVG text as .data or the text itself
            with open(output_file, 'w') as fh:
                fh.write(getattr(img, 'data', img))
        elif output_ext == 'PNG':
            if svg:
                # historically nothing was written in this case, so only warn
                logging.warning("Output file extension is PNG but image format is SVG, so '%s' was not written!", output_file)
            elif hasattr(img, 'save'):
                # PIL-like image
                img.save(output_file)
            else:
                # presumably an object exposing the PNG bytes as .data -- verify against the RDKit version in use
                with open(output_file, 'wb') as fh:
                    fh.write(img.data)
        else:
            raise ValueError(f"Error! Unsupported extension '{output_ext}'!")

    return img
def _get_edge_info(G: Graph, edge_attributes: List[str], attribute_names: bool, label_node_names_on_edges: bool) -> Dict: """ Use the first associated data of edges for edge labelling of a networkx graph. :param G: a Fragment Combination graph :param edge_attributes: a list of edge attributes to represent on the figure :param attribute_names: display the attribute names on the figure (name: value) :return: a Dict of syntax {(node1, node2): data} """ d = {} for edge in list(G.edges(data=True)): # determine what properties to keep by their names (defined as list in G attr) if edge_attributes is None: data = edge[2] else: data = {k: v for k, v in edge[2].items() if k in edge_attributes} # cannot use a dict for labelling edges, so just display values if attribute_names: data = [f"{k}: {v}" for k, v in data.items()] else: data = list(data.values()) if label_node_names_on_edges: data.append(f"s: {edge[0]}") data.append(f"t: {edge[1]}") # format them d[(edge[0], edge[1])] = '; '.join(data) return d
def graph(G: Graph,
          colormap_nodes: List[Tuple[float]] = None,
          output_file: str = None,
          fig_size: tuple = (8, 8),
          edge_attributes: List[str] = ['fcc'],
          attribute_names: bool = False,
          orientate: bool = False,
          label_node_names_on_edges: bool = False) -> Figure:
    """
    Return a matplotlib Figure of a networkx graph.

    :param G: a networkx Graph object of the fragment combinations, or its encoded
        form as bytes (decoded with utils.decode_object)
    :param colormap_nodes: a colormap of RGB values for the nodes
        (i.e. [(0, 0, 1), (0, 1, 0)]) or a ColorMap object
    :param output_file: if specified, the image is saved (format is deduced from extension)
    :param fig_size: the size of the matplotlib figure
    :param edge_attributes: a list of edge attributes to represent on the figure
        (the default list is only read, never modified)
    :param attribute_names: display the attribute names on the figure (name: value)
    :param orientate: draw the graph as a directed graph. Do not use this on larger networks!
    :param label_node_names_on_edges: append source and target node names to the edge labels
    :return: a matplotlib Figure object
    """
    # encoded graphs are passed as binary data
    if isinstance(G, (bytes, bytearray)):
        G = utils.decode_object(G)

    # to orientate the graph, this is uselessly complicated, do not use this on larger networks!
    if orientate:
        H = nx.DiGraph(G, data=True)  # this creates not only node1 -> node2 but also node2 -> node1
        H.remove_edges_from(G.edges())  # this removes node1 -> node2 because it is found in the undirected graph
        G = H.reverse()  # this transforms the remaining node2 -> node1 into node1 -> node2

    if colormap_nodes is None:
        # define a 2D list instead of a single tuple to avoid matplotlib warning
        colormap_nodes = [(0.7, 0.7, 0.7)] * len(list(G.nodes()))
    elif isinstance(colormap_nodes, ColorMap):
        # define a list of colors mapped to the node iteration in G
        val_map = {k: v[0] for k, v in colormap_nodes.fragments.items()}
        # if fragment id is not found in colormap, paint node in black instead
        # (as an RGB tuple, so the list of colors stays homogeneous)
        colormap_nodes = [val_map.get(node, (0.0, 0.0, 0.0)) for node in G.nodes()]

    pos = nx.spring_layout(G)
    edges_info = _get_edge_info(G, edge_attributes, attribute_names, label_node_names_on_edges)
    figure = plt.figure(figsize=fig_size)
    nx.draw(G,
            pos,
            edge_color='black',
            width=1,
            linewidths=1,
            node_size=2000,
            node_color=colormap_nodes,
            alpha=1,
            font_size=16,
            with_labels=True,
            connectionstyle='arc3,rad=0.9',
            )
    nx.draw_networkx_edge_labels(G,
                                 pos,
                                 edge_labels=edges_info,
                                 font_color='red',
                                 font_size=14,
                                 )

    if output_file is not None:
        output_file_format = str(output_file).split('.')[-1].lower()
        plt.savefig(output_file, format=output_file_format)

    # close the current figure so it is not displayed twice in notebooks
    plt.close()

    return figure
def compress_parallel_edges(G):
    """Collapse the parallel edges of a FCG (networkx MultiGraph) into single labelled edges.

    This is an extremely unoptimized preprocessing step so that the graph can be
    drawn more nicely without parallel edges. The 'idm' and 'idcfg' values of the
    first edge are applied to all edges of the result.

    :param G: a graph whose edges carry the attributes 'idm', 'idcfg' and 'fcc'
    :return: a simple graph with edge attributes 'idm', 'idcfg', 'label' and 'title'
    """
    def _format_label(row):
        # format labels: "cm x4", "fs, fe"
        parts = []
        for fcc, n_fcc in zip(row['fcc'], row['n_fcc']):
            parts.append(f"{fcc}" if n_fcc == 1 else f"{fcc} x{n_fcc}")
        return ', '.join(parts)

    # get the edges as a df
    edges = nx.convert_matrix.to_pandas_edgelist(G)

    # values shared by all edges are read from the first one
    first_edge = edges.iloc[0]
    idm = first_edge['idm']
    idfcg = first_edge['idcfg']

    # 1st groupby: get the count of occurrences in groups of s, t, and fcc
    per_type = ['source', 'target', 'fcc']
    edges['n_fcc'] = 1
    edges['n_fcc'] = edges.groupby(per_type)['n_fcc'].transform('sum')
    edges = edges.drop_duplicates(per_type)

    # wrap values into lists so that summing them concatenates them
    edges['fcc'] = edges['fcc'].map(lambda x: [x])
    edges['n_fcc'] = edges['n_fcc'].map(lambda x: [x])

    # 2nd groupby: gather the types and counts in groups of s, t
    edges = edges.groupby(['source', 'target']).agg({'fcc': 'sum', 'n_fcc': 'sum'})

    # apply general values
    edges['idm'] = idm
    edges['idcfg'] = idfcg
    edges['title'] = f"{idm}:{idfcg}"
    edges['label'] = edges.apply(_format_label, axis=1)
    edges = edges.reset_index()

    # simple graph because no more parallel edges
    return nx.from_pandas_edgelist(edges,
                                   source="source",
                                   target="target",
                                   edge_attr=['idm', 'idcfg', 'label', 'title'])
def fcg(G, colormap=None, WD_img: str = None, output_file: str = None, size: tuple = (400, 400), title: str = '__graph__'):
    """A function to represent Fragment Combination Graphs.

    Currently, this function has limitations for showing fragments into the nodes:

        1. Individual fragment images have to be generated beforehand and saved into the WD_img directory
        2. Fragment images must be in PNG format
        3. Fragment image files have to comply the naming scheme: "fragment_id.png"
        4. Fragment image backgrounds must be transparent (otherwise the node color will not be visible)

    To compute individual round PNG images of the fragments (and match criteria 1-3),
    one can use the command below:

        # create a directory with individual fragment images (round, png)
        mols_draw input_file output_dir

    Currently the generated PNG images are not transparent, so they have to be
    edited with ImageMagick (from bash):

        # make all PNG files in output dir transparent
        cd output_dir
        for f in *png; do echo $f; convert $f -transparent white $f; done

    :param G: a Graph (NetworkX), generated with the fragment_combination_graph module
    :param colormap: a matching ColorMap object, generated alongside the graph.
        If none is provided, the nodes will be colored in white.
    :param WD_img: a working directory where suitable (see above) fragment images are
        located. If none is provided, fragments will not be displayed in the graph nodes.
    :param output_file: the PNG file to write the graph to. If none is provided, a
        file named '_tmp_fcg.png' in the system temporary directory is used.
    :param size: the size (width, height) used for displaying the drawing
    :param title: add the specified title below the graph, by default
        molecule_id:fcg_id, or nothing if None
    :return: a drawing of the fragment combination graph
    """
    import tempfile

    # preprocess nx G
    G = compress_parallel_edges(G)

    # export from nx to Graphviz
    A = to_agraph(G)

    # configure graph attributes
    # A.graph_attr.update(ratio="fill")
    A.graph_attr.update(size="12, 25")
    A.graph_attr['outputorder'] = 'edgesfirst'
    A.graph_attr['forcelabels'] = 'true'
    A.graph_attr['nodesep'] = '2'
    A.graph_attr['dpi'] = '1200'
    A.graph_attr['fontsize'] = 25
    if title == '__graph__':
        # all edges share the same title, so read it from the first one
        A.graph_attr['label'] = "\n\n" + list(G.edges(data=True))[0][2]['title']
    elif title is not None:
        A.graph_attr['label'] = title

    # configure node attributes
    for nl in G.nodes():
        n = A.get_node(nl)
        n.attr['color'] = 'black'
        n.attr['style'] = 'filled'
        n.attr['imagescale'] = True
        n.attr['fixedsize'] = True
        n.attr['shape'] = 'circle'
        n.attr['labeldistance'] = 1
        n.attr['penwidth'] = 1
        n.attr['height'] = 2
        n.attr['width'] = 2
        n.attr['fontsize'] = 20

        # embed fragment images, if available
        if WD_img is not None:
            n.attr['image'] = f"{WD_img}/{nl}.png"
            # push the label below the image; formatting also supports non-string node names
            n.attr['label'] = f"\n\n\n\n\n{nl}"
        else:
            n.attr['label'] = nl

        # colormap
        if colormap is None:
            n.attr['fillcolor'] = "white"
        else:
            n.attr['fillcolor'] = matplotlib.colors.to_hex(colormap.fragments[nl][0])

    # configure edge attributes (G is a simple graph here, so edges can be looked up by their nodes)
    for source, target, attributes in G.edges(data=True):
        e = A.get_edge(source, target)
        e.attr['label'] = "    " + attributes['label']
        e.attr['labelfontcolor'] = 'red'

    # setup export
    if output_file is None:
        # NOTE: fixed file name, so concurrent calls overwrite each other's export
        output_file = str(Path(tempfile.gettempdir()) / '_tmp_fcg.png')

    # export the graph as PNG (cannot embed SVG...)
    A.draw(output_file, format='png', prog='dot')

    # read back the export
    return Image(output_file, width=size[0], height=size[1])
def rescale(mol: Mol, f: float = 1.4):
    """Scale the coordinates of a Mol by a factor f.

    The scaling is expressed as a 4x4 transformation matrix (f on the first three
    diagonal entries, 1 on the last one) that is applied with AllChem.TransformMol.

    :param mol: a Mol which is modified in place.
    :param f: the factor for rescaling coordinates
    """
    diagonal = np.array([f, f, f, 1.0], dtype=np.double)
    AllChem.TransformMol(mol, np.diag(diagonal))
def depict_mol(mol: Mol, methods: List[str] = ["CoordGen", "rdDepictor"], consider_input: bool = True) -> Mol:
    """
    Returns the "best" 2D depiction of a molecule according to the specified methods.

    Currently two methods are available:
        - CoordGen
        - rdDepictor

    A perfect score of 0 means the depiction is good enough (no overlapping
    atoms/bonds) and it is not worth computing other depictions, so it is returned
    immediately. When no perfect score is reached, the depiction with the lowest
    score is retrieved. In case of tie, the first method applied is preferred.

    In case the input molecule contains input coordinates, they can be compared to
    the methods as 'Input' (lowest priority).

    The method used for depicting the molecule is stored as molecule property: "_2D".

    For CoordGen, 2D representations are automatically rescaled (see rescale).

    :param mol: the input molecule. It is only modified (property "_2D" and atom
        property 'name') when its input coordinates are evaluated.
    :param methods: a list of methods to apply. Currently supported: CoordGen,
        rdDepictor. The default list is only read, never modified.
    :param consider_input: consider the input coordinates (if any), for determining
        the best 2D representation
    :return: the molecule with 2D coordinates and a new "_2D" property with the
        information of which depictor was selected.
    :raises ValueError: if no depiction is available at all (no methods and no
        usable input coordinates)
    """
    # methods
    METHODS = {'CoordGen': lambda x: rdCoordGen.AddCoords(x),
               'rdDepictor': lambda x: rdDepictor.Compute2DCoords(x),
               'Input': lambda x: x,
               }

    def _score(candidate: Mol):
        """Return the depiction score of a molecule (0 is a perfect score)."""
        # NOTE(review): DepictionValidator is not imported in the visible part of
        # this module -- presumably defined or imported elsewhere; verify.
        dv = DepictionValidator(candidate)
        # bug fix for version 5+, add name property to atoms
        for a in dv.mol.GetAtoms():
            a.SetProp('name', str(a.GetIdx()))
        return dv.depiction_score()

    depictions = OrderedDict()
    for method in methods:
        # copy the input mol so input coordinates are not modified
        depiction_mol = Mol(mol)
        # compute the depiction (in place)
        METHODS[method](depiction_mol)
        # coordgen creates very small 2D representations, so let's rescale them
        if method == "CoordGen":
            rescale(depiction_mol)
        # score the depiction
        depiction_score = _score(depiction_mol)
        # exit if perfect score, record depiction for selection otherwise
        if depiction_score == 0:
            depiction_mol.SetProp("_2D", method)
            return depiction_mol
        depictions[method] = (depiction_score, depiction_mol)

    # no perfect score was reached until now, so test input coordinates if any
    if consider_input:
        if mol.GetNumConformers() > 0:
            method = "Input"
            depiction_score = _score(mol)
            if depiction_score == 0:
                mol.SetProp("_2D", method)
                return mol
            depictions[method] = (depiction_score, mol)
        else:
            logging.debug("No input coordinates to use for Input method, so skipping it!")

    if not depictions:
        raise ValueError("Error! No depiction available: no method was specified and no input coordinates could be used!")

    # retrieve best depiction possible (min returns the first method in case of tie)
    best_method = min(depictions, key=lambda k: depictions[k][0])
    best_depiction_mol = depictions[best_method][1]
    best_depiction_mol.SetProp("_2D", best_method)

    return best_depiction_mol
def reaction(mol1: Mol, mol2: Mol, sub_img_size: tuple = (200, 200), svg: bool = True, output_file: str = None):
    """Wrapper function around RDKit ReactionToImage function.

    If the molecules are Mol objects, they are converted to Smiles. If not, they
    are assumed to be already Smiles.

    .. warning:: There is currently no way of not displaying aromatic rings instead of kekulized rings.

    The original SMILES for the reaction can be displayed using a DEBUG logging level.

    :param mol1: the molecule to display left
    :param mol2: the molecule to display right
    :param sub_img_size: the size of the molecules
    :param svg: return the image in SVG text, return a PIL Image otherwise.
    :param output_file: if specified, the image is saved (format is deduced from
        the extension, SVG or PNG, and has to match the svg argument)
    :return: the reaction as an image.
    :raises ValueError: if the extension of output_file does not match the image
        format or is neither SVG nor PNG
    """
    if isinstance(mol1, Mol):
        mol1 = Chem.MolToSmiles(mol1)
    if isinstance(mol2, Mol):
        mol2 = Chem.MolToSmiles(mol2)

    rxn_str = f"{mol1}>>{mol2}"
    logging.debug("rxn_str='%s'", rxn_str)
    rxn = rdChemReactions.ReactionFromSmarts(rxn_str, useSmiles=True)
    img = Draw.ReactionToImage(rxn, subImgSize=sub_img_size, useSVG=svg)

    # export img
    if output_file is not None:
        output_ext = str(output_file).split('.')[-1].upper()
        if output_ext == 'SVG':
            if not svg:
                raise ValueError("Error! output file extension is SVG but image format is PNG!")
            with open(output_file, 'w') as fh:
                fh.write(img)  # inconsistent svg attributes between mols and reactions
        elif output_ext == 'PNG':
            if svg:
                raise ValueError("Error! output file extension is PNG but image format is SVG!")
            # without svg, the image is a PIL Image which deduces the format from the extension
            img.save(output_file)
        else:
            raise ValueError(f"Error! Unsupported extension '{output_ext}'!")

    return img
def highlight_fragment(mol, fragment_id, palette):
    """Generate the highlight of a whole fragment with a given color.

    All atoms of the molecule are highlighted. It can be useful for more complex
    reports to show the colored fragments below a molecule with a colormap.

    :param mol: a molecule to highlight
    :param fragment_id: the identifier under which all atoms are registered
    :param palette: forwarded as the third positional argument of FragmentHighlight
        (NOTE(review): that parameter is named fragments_colors there -- confirm
        what callers are expected to pass)
    :return: a FragmentHighlight object
    """
    all_atom_indices = list(range(mol.GetNumAtoms()))
    return FragmentHighlight(mol, {fragment_id: [all_atom_indices]}, palette)
def display_fragments(df_fcg, fragment_colors, sort=True):
    """Display a table of images of the highlighted fragments found in a DataFrame.

    :param df_fcg: a DataFrame with a column '_d_mol_frags' containing dictionaries
        of syntax {fragment_id: fragment}
    :param fragment_colors: the colors forwarded to highlight_fragment for every fragment
    :param sort: sort the fragments by their ids (numerically if all ids can be
        converted to integers). Ids are converted to strings afterwards.
    :return: the result of notebook.display_image_table
    """
    # merge the dictionaries of all rows into a single one
    d_frags = {k: v for d in df_fcg['_d_mol_frags'].values for k, v in d.items()}

    if sort:
        try:
            d_frags = {int(k): v for k, v in d_frags.items()}
        except ValueError:
            # at least one id is not numerical, so keep the ids as they are
            pass
        d_frags = dict(sorted(d_frags.items()))
        d_frags = {str(k): v for k, v in d_frags.items()}

    imgs_frags = []
    for fragment_id, fragment in d_frags.items():
        colormap = highlight_fragment(fragment, fragment_id, fragment_colors)
        imgs_frags.append(mol(fragment, colormap, img_size=(170, 170), legend=fragment_id))

    return notebook.display_image_table(imgs_frags, max_img_per_row=10)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CLASSES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
class ColorMap:
    """A class containing all the required information for highlighting a Molecule.

    It is represented by the colors of fragments, atoms and bonds:
        - fragments: {fragment_id: [color of occurrence 1, color of occurrence 2, ...]}
        - atoms: {atom_index: [color, ...]}
        - bonds: {bond_index: [color, ...]}
    """

    def __init__(self, mol: Mol, d_aidxs: dict, palette: str = None, color_shades: float = 0.0, infer_colors: bool = True):
        """
        :param mol: the molecule to highlight.
        :param d_aidxs: a dictionary containing fragment ids as keys and molecule atom indices as values, i.e. {'frag1': [(0, 1, 2)]. 'frag2': [(2, 3, 4), (5, 6, 7)]}
        :param palette: a seaborn palette defined by a string. A list of possible palette names can be found at: http://www.python-simple.com/img/img45.png. If none is provided, an intern palette is used instead.
        :param color_shades: use a darker color shade each time a fragment is repeated in the molecule. By default (0.0) the same color shade is applied each time, whereas a value of 0.05 means 5% darker for each repetition.
        :param infer_colors: currently unused.
        :raises ValueError: if color_shades is not in the range [0.0, 1.0]
        """
        if color_shades < 0.0 or color_shades > 1.0:
            raise ValueError(f"Error! Argument color_shades value is expected to be found in the range [0.0, 1.0], but '{color_shades}' was found instead!")

        colormap_atoms = {}  # atoms
        colormap_bonds = {}  # bonds
        colormap_fragments = {}  # fragments are colored and stored in order

        if palette is None:
            palette = DEFAULT_PALETTE
        else:
            palette = sns.color_palette(palette)
            palette = {f"COL_{str(i+1).zfill(2)}": color for i, color in enumerate(palette)}

        # colors are sorted, so always red, then green, etc.
        palette_colors = list(palette.values())

        # pick a color for each fragment type
        for i, (fragment_id, aidxs_l) in enumerate(d_aidxs.items()):
            # loop back to the first color once all colors have been used
            color = palette_colors[i % len(palette_colors)]
            colormap_fragments[fragment_id] = []  # frag1: [(0, 0, 1)]

            # pick a shade for each occurrence of a same fragment type
            k = 0
            for j, aidxs in enumerate(aidxs_l):
                # darker for each fragment of the same type; channels are kept >= 0
                # so that large color_shades values cannot produce invalid colors
                new_color = tuple(max(0.0, x * (1.0 - color_shades * k)) for x in color)
                colormap_fragments[fragment_id].append(new_color)

                # color atoms
                for aidx in aidxs:
                    colormap_atoms.setdefault(aidx, []).append(new_color)

                # color bonds
                for bidx in self._get_bidxs(mol, aidxs):
                    colormap_bonds.setdefault(bidx, []).append(new_color)

                # continue to darken only if it does not end up being black
                if j < 6:
                    k += 1
                else:
                    k = 0  # reset to default

        # init attributes
        self.fragments = colormap_fragments
        self.atoms = colormap_atoms
        self.bonds = colormap_bonds
        self.palette = palette

    def __repr__(self):
        """Return a string representation of the ColorMap object.

        Listing a huge dictionary of tuples is not that great for representing data,
        so only the color of the first occurrence of each fragment is displayed, by
        its name in the palette. If the color cannot be found in the palette, the
        color itself is displayed instead.
        """
        palette_names = list(self.palette.keys())
        palette_colors = list(self.palette.values())

        frags = []
        for fragment_id, colors in self.fragments.items():
            if not colors:
                color_name = None  # fragment without any occurrence
            else:
                try:
                    color_name = palette_names[palette_colors.index(colors[0])]
                except ValueError:
                    color_name = colors[0]
            frags.append(f"{fragment_id}: {color_name}")

        return 'ColorMap={' + ', '.join(frags) + '}'

    def _get_bidxs(self, mol: Mol, aidxs: Set[int]) -> Set[int]:
        """From an iterable of atom indices (aidxs), return a set of bond indices (bidxs) found between atoms in aidxs.

        :param mol: the molecule to highlight
        :param aidxs: a set with the atom indices of the fragment to highlight
        :return: a set with the corresponding bond indices
        """
        # count how many times each bond is reached from the atoms in aidxs
        counter = Counter(b.GetIdx() for aidx in aidxs for b in mol.GetAtomWithIdx(aidx).GetBonds())
        # a bond reached more than once has both of its atoms in aidxs
        return {bidx for bidx, count in counter.items() if count > 1}

    def blend(self):
        """Blend colors found in a ColorMap. The ColorMap object is modified in place:
        every list of colors in atoms and bonds is replaced by a one-element list
        with the blended color.
        """
        for k, v in self.atoms.items():
            self.atoms[k] = self._blend_multiple_colors(v)
        for k, v in self.bonds.items():
            self.bonds[k] = self._blend_multiple_colors(v)

    def _blend_two_colors(self, color1: Tuple[int], color2: Tuple[int], alpha: float = 0.5) -> Tuple[int]:
        """Blend two colors after Downgoat's reply at:
        https://stackoverflow.com/questions/726549/algorithm-for-additive-color-mixing-for-rgb-values

        The formula below is applied to each color channel (R, G, and B):

        .. math::
            blended = \\sqrt{(1 - alpha) * val1^2 + alpha * val2^2}

        with val1: color channel in color1; val2: corresponding color channel in color2; alpha: transparency factor

        :param color1: the first color to blend
        :param color2: the second color to blend
        :param alpha: the blending factor, 0.5 means as much of color1 as of color2
        :return: the blended color
        """
        return tuple(sqrt((1 - alpha) * (x1 ** 2) + alpha * (x2 ** 2)) for x1, x2 in zip(color1, color2))

    def _blend_multiple_colors(self, colors: List[Tuple[int]]) -> List[Tuple[int]]:
        """Given an iterable of colors represented in RGB format (i.e. [(1, 0, 0), (0, 1, 0), (0, 0, 1)]),
        iteratively apply the _blend_two_colors function, so that every color gets the same
        weight in the result (the i-th added color is blended with alpha = 1/(i+1)).

        :param colors: the colors to blend
        :return: the input list if it contains less than 2 colors, a one-element list with the blended color otherwise
        """
        # if just 1 color, just return it as it is
        if len(colors) < 2:
            return colors

        # iterate over colors
        color_result = colors[0]  # initiate at the first color
        for i in range(1, len(colors)):
            color_to_add = colors[i]
            old_alpha = 1 - (1 / (i + 1))
            alpha = 1 - old_alpha
            color_result = self._blend_two_colors(color_result, color_to_add, alpha)
            # when I mix red and green, I want to get some yellow, not some olive color,
            # so here I hardcode the yellow I long for (golden yellow)
            if tuple(round(x, 4) for x in color_result) == (0.7211, 0.8246, 0.4472):
                color_result = (1.0, 0.9294, 0.0)
            # in case of the same color being applied, use a 10% darker shade
            if color_result == color_to_add:
                color_result = tuple(x * 0.90 for x in color_result)

        return [color_result]
def cap_rgb_val(x):
    """Clamp a single RGB channel value to the range [0, 1].

    :param x: the channel value to clamp
    :return: 1.0 if x is above 1, 0.0 if x is below 0, x itself otherwise
    """
    if x > 1:
        # saturate instead of wrapping to a low value, which would change the hue
        return 1.0
    elif x < 0:
        return 0.0
    return x


def attribute_colors_to_fragments(d_aidxs, palette, color_gradient=0.1):
    """Attribute one color of the palette to each fragment id.

    Colors are taken from the palette in order. Once all colors have been
    used, they are recycled with their channels scaled by
    (1 - color_gradient * k) on one pass and (1 + color_gradient * k) on the
    next, k being increased each time both variants have been produced.
    Scaled channels are clamped to [0, 1].

    :param d_aidxs: an iterable of fragment ids (only the ids are used, i.e. the keys when a dict is provided)
    :param palette: an object exposing ``colors`` (a sequence of RGB tuples) and ``num_colors``
    :param color_gradient: the step applied to the scaling factor when recycling colors
    :return: a dict with fragment ids as keys and RGB tuples as values
    """
    highlights = {}
    num_colors = palette.num_colors
    color_idx = 0  # current color index
    step = 0  # number of gradient steps applied to the current pass
    scale_up = True  # True: factor is 1 + gradient * step, False: 1 - gradient * step

    for fragment_id in d_aidxs:
        if scale_up:
            factor = 1.0 + color_gradient * step
        else:
            factor = 1.0 - color_gradient * step
        highlights[fragment_id] = tuple(cap_rgb_val(x * factor) for x in palette.colors[color_idx])

        # since there is a limited number of colors, recycle them
        color_idx += 1
        if color_idx >= num_colors:
            color_idx = 0
            scale_up = not scale_up
            # increase the gradient once both variants have been used
            if not scale_up:
                step += 1

    return highlights


class Palette:
    """An ordered list of RGB colors used for highlighting fragments."""

    def __init__(self, colors=None):
        """
        :param colors: the colors of the palette. Either None (the DEFAULT_PALETTE colors are used),
            a str (passed to seaborn.color_palette), a dict (its values are used) or an iterable
            of colors accepted by matplotlib.colors.to_rgb.
        """
        if colors is None:
            colors = list(DEFAULT_PALETTE.values())
        else:
            if isinstance(colors, str):
                colors = sns.color_palette(colors)
            if isinstance(colors, dict):
                colors = list(colors.values())

        self.colors = [matplotlib.colors.to_rgb(x) for x in colors]  # normalized RGB tuples
        self.colors_ini = colors  # the colors as provided, before conversion
        self.num_colors = len(self.colors)

    def show(self):
        """Display the palette with seaborn.palplot."""
        return sns.palplot(sns.color_palette(self.colors), size=1)

    def __repr__(self):
        plural = 's' if self.num_colors > 1 else ''
        return f"Palette ({self.num_colors} color{plural})"
class FragmentHighlight:
    """A class containing all the required information for highlighting a molecule's fragments.

    It is represented by the count of fragments, atom- and bond colors.

    Destined to replace the ColorMap class.
    """

    def __init__(self,
                 mol: Mol,
                 atoms_to_highlight: dict,
                 fragments_colors: dict = None,
                 palette: Palette = None,
                 color_gradient: float = 0.2):
        """
        :param mol: the molecule to highlight.
        :param atoms_to_highlight: a dictionary containing fragment ids as keys and lists of molecule atom indices as values,
            i.e. {'frag1': [(0, 1, 2)], 'frag2': [(2, 3, 4), (5, 6, 7)]}. It has to be a dictionary, None is not handled.
        :param fragments_colors: a dictionary attributing a color to each fragment id. Fragment ids have to match those
            defined in atoms_to_highlight. If a fragment id is missing, then it will not be highlighted. If this argument
            is not set, the palette will be used to attribute colors to all fragments defined in atoms_to_highlight.
        :param palette: a Palette object used to attribute colors to fragments, when fragments_colors are not defined.
            If none is defined, the default palette will be used.
        :param color_gradient: the step used for recycling colors when all colors of the palette have been already used.
            It has to be in the range [0.0, 1.0].
        :raises ValueError: if color_gradient is outside of [0.0, 1.0].
        """
        # validate first, so an invalid gradient is never used for deriving colors
        if color_gradient < 0.0 or color_gradient > 1.0:
            raise ValueError(f"Error! Argument color_gradient value is expected to be found in the range [0.0, 1.0], but '{color_gradient}' was found instead!")

        if fragments_colors is None:
            # define palette only if necessary
            if palette is None:
                palette = Palette(list(DEFAULT_PALETTE.values()))
            fragments_colors = attribute_colors_to_fragments(atoms_to_highlight, palette, color_gradient)
        elif palette is not None:
            print("Warning! Palette is used only when fragments_colors are not specified.")

        # atoms: one (initially empty) list of colors per atom index found in any fragment
        all_aidxs = chain.from_iterable(chain.from_iterable(atoms_to_highlight.values()))
        highlight_atoms = {aidx: [] for aidx in sorted(all_aidxs)}
        # bonds: entries are created only for bonds that receive a color
        highlight_bonds = {}

        # apply colors
        for fragment_id, aidxs_l in atoms_to_highlight.items():
            if fragment_id not in fragments_colors:
                continue  # no color defined for this fragment: do not highlight it
            color = fragments_colors[fragment_id]
            for aidxs in aidxs_l:
                for aidx in aidxs:
                    highlight_atoms[aidx].append(color)
                for bidx in self._get_bidxs(mol, aidxs):
                    highlight_bonds.setdefault(bidx, []).append(color)

        # init attributes
        self.fragments = {k: v for k, v in fragments_colors.items() if k in atoms_to_highlight}
        self.num_fragments = len(self.fragments)
        self.atoms = highlight_atoms
        self.bonds = highlight_bonds

    def __repr__(self):
        return f"FragmentHighlight (n={self.num_fragments})"

    def _get_bidxs(self, mol: Mol, aidxs: Set[int]) -> Set[int]:
        """From an iterable of atom indices (aidxs), return a set of bond indices (bidxs)
        found between atoms in aidxs.

        :param mol: the molecule to highlight
        :param aidxs: the atom indices of the fragment to highlight
        :return: a set with the corresponding bond indices
        """
        # A bond lies inside the fragment when both of its atoms are listed, i.e. when
        # it is seen twice while visiting the bonds of each atom. Atom indices are
        # deduplicated so that a repeated index cannot make a bond leaving the
        # fragment look like an internal one.
        counter = Counter(
            bond.GetIdx()
            for aidx in set(aidxs)
            for bond in mol.GetAtomWithIdx(aidx).GetBonds()
        )
        return {bidx for bidx, count in counter.items() if count > 1}
class DepictionValidator:
    """Toolkit for estimation of depiction quality.

    This is not my code, it was copied from:
    https://gitlab.ebi.ac.uk/pdbe/ccdutils/blob/master/pdbeccdutils/core/depictions.py

    Only the little piece required for this project was extracted, so that the
    pdbeccdutils library does not have to be installed in the environment.
    """

    def __init__(self, mol):
        """
        Args:
            mol: the molecule to evaluate. Its conformer (mol.GetConformer())
                provides the atom positions the depiction is judged on.
        """
        self.mol = mol
        self.conformer = mol.GetConformer()
        self.bonds = self.mol.GetBonds()

        positions = [self.conformer.GetAtomPosition(i) for i in range(self.conformer.GetNumAtoms())]
        atom_centers = [[pos.x, pos.y, pos.z] for pos in positions]
        # spatial index used for the neighborhood queries on atom positions
        self.kd_tree = KDTree(atom_centers)

    def _intersection(self, bondA, bondB):
        """True if two bonds collide, false otherwise.

        Bonds sharing a common atom are reported as colliding only when the
        angle they form is below 10 degrees. Otherwise Cramer's rule is used
        for solving the linear equations system of the two segments.

        Args:
            bondA (rdkit.Chem.rdchem.Bond): this bond
            bondB (rdkit.Chem.rdchem.Bond): other bond

        Returns:
            bool: true if bonds collide, false otherwise.
        """
        atoms = [bondA.GetBeginAtom(), bondA.GetEndAtom(), bondB.GetBeginAtom(), bondB.GetEndAtom()]
        # atoms are identified by their 'name' property when it is set and by
        # their index otherwise, so molecules without atom names do not raise
        names = [a.GetProp('name') if a.HasProp('name') else a.GetIdx() for a in atoms]
        points = [self.conformer.GetAtomPosition(a.GetIdx()) for a in atoms]

        vecA = Geometry.Point2D(points[1].x - points[0].x, points[1].y - points[0].y)
        vecB = Geometry.Point2D(points[3].x - points[2].x, points[3].y - points[2].y)

        # we need to set up directions of the vectors properly in case
        # there is a common atom. So we identify angles correctly
        # e.g. B -> A; B -> C and not A -> B; C -> B.
        if len(set(names)) == 3:
            angle = self.__get_angle(names, vecA, vecB)
            return angle < 10.0

        # Cramer's rule to identify intersection
        det = vecA.x * -vecB.y + vecA.y * vecB.x
        if round(det, 2) == 0.00:
            # (nearly) parallel bonds
            return False

        a = points[2].x - points[0].x
        b = points[2].y - points[0].y

        # p: position of the crossing point along bondA (0 = begin, 1 = end)
        detP = (a * -vecB.y) - (b * -vecB.x)
        p = round(detP / det, 3)
        if p < 0 or p > 1:
            return False

        # r: position of the crossing point along bondB
        detR = (vecA.x * b) - (vecA.y * a)
        r = round(detR / det, 3)
        return 0 <= r <= 1

    def __find_element_with_max_occurrence(self, array):
        """Find element with most occurrences in the list.

        In case of a tie, the element encountered first is returned.

        Args:
            array (list of str): Array to be searched

        Returns:
            str: Value with most occurrences in the list, '' if the list is empty.
        """
        counts = Counter(array)
        if not counts:
            return ''
        # max returns the first maximal key in insertion order
        return max(counts, key=counts.get)

    def __get_angle(self, names, vecA, vecB):
        """Get the size of the angle formed by two bonds which share common atom.

        Args:
            names (list of str): List of atom names forming bonds [A, B, C, D] for AB and CD.
            vecA (Geometry.Point2D): Vector representing AB bond.
            vecB (Geometry.Point2D): Vector representing CD bond.

        Returns:
            float: Size of the angle in degrees.
        """
        pivot = self.__find_element_with_max_occurrence(names)

        # both vectors have to start from the shared atom
        if names[0] != pivot:
            vecA = vecA * -1
        if names[2] != pivot:
            vecB = vecB * -1

        radians = vecA.AngleTo(vecB)
        angle = 180 / math.pi * radians
        return angle

    def has_degenerated_atom_positions(self, threshold):
        """Detects whether the structure has a pair of atoms closer to each
        other than threshold.

        Arguments:
            threshold (float): Bottom line to use for spatial search.

        Returns:
            (bool): if such atomic pair is found
        """
        for i in range(self.conformer.GetNumAtoms()):
            center = self.conformer.GetAtomPosition(i)
            point = [center.x, center.y, center.z]
            surrounding = self.kd_tree.query_ball_point(point, threshold)
            # the atom itself is always part of the hits
            if len(surrounding) > 1:
                return True
        return False

    def count_suboptimal_atom_positions(self, lowerBound, upperBound):
        """Counts the pairs of atoms found at a distance in the range
        <lowerBound, upperBound>, meaning that the depiction could be improved.

        Arguments:
            lowerBound (float): lower bound
            upperBound (float): upper bound

        Returns:
            float: number of atoms having at least one neighbor in the range,
            divided by 2.
        """
        counter = 0
        for i in range(self.conformer.GetNumAtoms()):
            center = self.conformer.GetAtomPosition(i)
            point = [center.x, center.y, center.z]
            surroundingLow = self.kd_tree.query_ball_point(point, lowerBound)
            surroundingHigh = self.kd_tree.query_ball_point(point, upperBound)
            if len(surroundingHigh) - len(surroundingLow) > 0:
                counter += 1
        # each pair is seen once from each of its two atoms
        return counter / 2

    def count_bond_collisions(self):
        """Counts number of collisions among all bonds. Can be used for
        estimations of how 'wrong' the depiction is.

        Returns:
            int: number of bond collisions per molecule
        """
        errors = 0
        num_bonds = len(self.bonds)
        for i in range(num_bonds):
            for a in range(i + 1, num_bonds):
                if self._intersection(self.bonds[i], self.bonds[a]):
                    errors += 1
        return errors

    def has_bond_crossing(self):
        """Tells if the structure contains collisions.

        Returns:
            bool: Indication about bond collisions
        """
        return self.count_bond_collisions() > 0

    def depiction_score(self):
        """Calculate quality of the ligand depiction. The higher the worse.
        Ideally that should be 0.

        Returns:
            float: Penalty score.
        """
        collision_penalty = 1
        degenerated_penalty = 0.4

        bond_collisions = self.count_bond_collisions()
        degenerated_atoms = self.count_suboptimal_atom_positions(0.0, 0.5)

        score = collision_penalty * bond_collisions + degenerated_penalty * degenerated_atoms
        return round(score, 1)