| | import os |
| | import subprocess |
| | import warnings |
| | from datetime import datetime |
| | import signal |
| | from contextlib import contextmanager |
| | import numpy as np |
| | import torch |
| | import yaml |
| | from rdkit import Chem |
| | from rdkit.Chem import RemoveHs, MolToPDBFile |
| | from torch_geometric.nn.data_parallel import DataParallel |
| |
|
| | from models.all_atom_score_model import TensorProductScoreModel as AAScoreModel |
| | from models.score_model import TensorProductScoreModel as CGScoreModel |
| | from utils.diffusion_utils import get_timestep_embedding |
| | from spyrmsd import rmsd, molecule |
| |
|
| |
|
def get_obrmsd(mol1_path, mol2_path, cache_name=None):
    """Compute RMSDs between two molecule files using OpenBabel's `obrms` tool.

    Args:
        mol1_path: path to a molecule file, or an RDKit Mol (written to a
            cached PDB file first).
        mol2_path: same as `mol1_path`, for the second molecule.
        cache_name: optional tag for the cached obrms output file; defaults
            to a timestamp so concurrent calls don't collide.

    Returns:
        np.ndarray of float RMSD values, one per line of obrms output.
    """
    cache_name = datetime.now().strftime('date%d-%m_time%H-%M-%S.%f') if cache_name is None else cache_name
    os.makedirs(".openbabel_cache", exist_ok=True)
    # Non-string inputs are assumed to be RDKit Mol objects; obrms needs files.
    if not isinstance(mol1_path, str):
        MolToPDBFile(mol1_path, '.openbabel_cache/obrmsd_mol1_cache.pdb')
        mol1_path = '.openbabel_cache/obrmsd_mol1_cache.pdb'
    if not isinstance(mol2_path, str):
        MolToPDBFile(mol2_path, '.openbabel_cache/obrmsd_mol2_cache.pdb')
        mol2_path = '.openbabel_cache/obrmsd_mol2_cache.pdb'
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # NOTE(review): shell=True with interpolated paths is shell-injection
        # prone if paths are ever untrusted; kept as-is since output
        # redirection relies on the shell here.
        return_code = subprocess.run(f"obrms {mol1_path} {mol2_path} > .openbabel_cache/obrmsd_{cache_name}.rmsd",
                                     shell=True)
        print(return_code)
    obrms_output = read_strings_from_txt(f".openbabel_cache/obrmsd_{cache_name}.rmsd")
    # obrms prints "RMSD <name> <value>"; the value is the last whitespace field.
    rmsds = [line.split(" ")[-1] for line in obrms_output]
    # BUGFIX: np.float was removed in NumPy 1.24; the builtin float is the
    # documented replacement and yields the same float64 dtype.
    return np.array(rmsds, dtype=float)
| |
|
| |
|
def remove_all_hs(mol):
    """Return a copy of `mol` with every removable hydrogen stripped.

    Turns on all RemoveHsParameters flags so that hydrogens RDKit would
    normally preserve (isotopically labelled, atom-mapped, degree-zero,
    wedge-bonded, in S-groups, ...) are removed as well.
    """
    params = Chem.RemoveHsParameters()
    for flag in (
        'removeAndTrackIsotopes',
        'removeDefiningBondStereo',
        'removeDegreeZero',
        'removeDummyNeighbors',
        'removeHigherDegrees',
        'removeHydrides',
        'removeInSGroups',
        'removeIsotopes',
        'removeMapped',
        'removeNonimplicit',
        'removeOnlyHNeighbors',
        'removeWithQuery',
        'removeWithWedgedBond',
    ):
        setattr(params, flag, True)
    return RemoveHs(mol, params)
| |
|
| |
|
def read_strings_from_txt(path):
    """Read the text file at `path` and return its lines with trailing
    whitespace (including the newline) removed."""
    with open(path) as handle:
        return [raw.rstrip() for raw in handle]
| |
|
| |
|
def save_yaml_file(path, content):
    """Serialize `content` as YAML and write it to `path`, creating any
    missing parent directories.

    Args:
        path: destination file path (must be a str).
        content: any YAML-serializable object.

    Raises:
        AssertionError: if `path` is not a string.
    """
    assert isinstance(path, str), f'path must be a string, got {path} which is a {type(path)}'
    content = yaml.dump(data=content)
    # BUGFIX: the old guard checked for '/' in the path (misses os.sep on
    # other platforms) and makedirs without exist_ok raced with concurrent
    # writers. dirname truthiness + exist_ok covers both.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w') as f:
        f.write(content)
| |
|
| |
|
def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
    """Build an Adam optimizer over the model's trainable parameters, plus an
    optional ReduceLROnPlateau scheduler.

    Args:
        args: config with `lr`, `w_decay`, `scheduler`, `scheduler_patience`.
        model: the torch module being optimized.
        scheduler_mode: 'min' or 'max', forwarded to ReduceLROnPlateau.

    Returns:
        (optimizer, scheduler) — scheduler is None unless
        args.scheduler == 'plateau'.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.w_decay)

    if args.scheduler != 'plateau':
        print('No scheduler')
        return optimizer, None

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_mode,
        factor=0.7,
        patience=args.scheduler_patience,
        min_lr=args.lr / 100,
    )
    return optimizer, scheduler
| |
|
| |
|
def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False):
    """Instantiate a TensorProductScoreModel from the run configuration.

    Args:
        args: config namespace carrying model hyperparameters (ns, nv,
            num_conv_layers, embedding settings, ...).
        device: torch.device the model is moved to.
        t_to_sigma: callable mapping diffusion time to noise scale, forwarded
            to the model.
        no_parallel: if True, never wrap in DataParallel even on CUDA.
        confidence_mode: if True, the model is built as a confidence predictor
            (see num_confidence_outputs below).

    Returns:
        The model, wrapped in torch_geometric DataParallel when on CUDA and
        `no_parallel` is False, moved to `device`.
    """
    # Pick the all-atom architecture only when the config defines and enables
    # `all_atoms`; older configs without the key fall back to coarse-grained.
    if 'all_atoms' in args and args.all_atoms:
        model_class = AAScoreModel
    else:
        model_class = CGScoreModel

    timestep_emb_func = get_timestep_embedding(
        embedding_type=args.embedding_type,
        embedding_dim=args.sigma_embed_dim,
        embedding_scale=args.embedding_scale)

    # Presumably selects precomputed ESM residue embeddings for the protein
    # encoder when a path is configured — confirm against model_class.
    lm_embedding_type = None
    if args.esm_embeddings_path is not None: lm_embedding_type = 'esm'

    model = model_class(t_to_sigma=t_to_sigma,
                        device=device,
                        no_torsion=args.no_torsion,
                        timestep_emb_func=timestep_emb_func,
                        num_conv_layers=args.num_conv_layers,
                        lig_max_radius=args.max_radius,
                        scale_by_sigma=args.scale_by_sigma,
                        sigma_embed_dim=args.sigma_embed_dim,
                        ns=args.ns, nv=args.nv,
                        distance_embed_dim=args.distance_embed_dim,
                        cross_distance_embed_dim=args.cross_distance_embed_dim,
                        batch_norm=not args.no_batch_norm,
                        dropout=args.dropout,
                        use_second_order_repr=args.use_second_order_repr,
                        cross_max_distance=args.cross_max_distance,
                        dynamic_max_cross=args.dynamic_max_cross,
                        lm_embedding_type=lm_embedding_type,
                        confidence_mode=confidence_mode,
                        # One output per RMSD classification cutoff plus one
                        # (for the final bucket) when a cutoff list is
                        # configured; otherwise a single confidence output.
                        num_confidence_outputs=len(
                            args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
                            args.rmsd_classification_cutoff, list) else 1)

    if device.type == 'cuda' and not no_parallel:
        model = DataParallel(model)
    model.to(device)
    return model
| |
|
| |
|
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None):
    """Symmetry-corrected RMSD between two coordinate sets via spyrmsd.

    When `mol2` is None, the second structure is assumed to share `mol`'s
    atom types and connectivity. Wrapped in a 10-second time limit because
    the graph-isomorphism search can blow up on pathological molecules.
    """
    with time_limit(10):
        reference = molecule.Molecule.from_rdkit(mol)
        other = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else None
        if other is None:
            other_atomicnums = reference.atomicnums
            other_adjacency = reference.adjacency_matrix
        else:
            other_atomicnums = other.atomicnums
            other_adjacency = other.adjacency_matrix
        result = rmsd.symmrmsd(
            coords1,
            coords2,
            reference.atomicnums,
            other_atomicnums,
            reference.adjacency_matrix,
            other_adjacency,
        )
    return result
| |
|
| |
|
| | class TimeoutException(Exception): pass |
| |
|
| |
|
@contextmanager
def time_limit(seconds):
    """Context manager that raises TimeoutException if its body runs longer
    than `seconds` of wall-clock time.

    Uses SIGALRM, so this is Unix-only and must be entered from the main
    thread.

    Args:
        seconds: integer timeout passed to signal.alarm.

    Raises:
        TimeoutException: from inside the body when the alarm fires.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # BUGFIX: remember the handler that was installed before us; the original
    # code left our handler installed forever, clobbering any other SIGALRM
    # user after the context exited.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
| |
|
| |
|
class ExponentialMovingAverage:
    """ from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py
    Maintains (exponential) moving average of a set of parameters. """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay.
            use_num_updates: Whether to use number of updates when computing
                averages.
        """
        if not 0.0 <= decay <= 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        # Detached copies of the trainable parameters; these hold the average.
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        # Scratch space used by store()/restore().
        self.collected_params = []

    def update(self, parameters):
        """Fold the current parameter values into the running average.

        Call once after every `optimizer.step()` with the same parameter
        iterable used at construction time.
        """
        effective_decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: use a smaller decay early on so the average tracks the
            # rapidly-changing parameters, approaching self.decay over time.
            effective_decay = min(effective_decay,
                                  (1 + self.num_updates) / (10 + self.num_updates))
        blend = 1.0 - effective_decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for shadow, live in zip(self.shadow_params, trainable):
                # In-place: shadow <- shadow - blend * (shadow - live)
                shadow.sub_(blend * (shadow - live))

    def copy_to(self, parameters):
        """Overwrite `parameters` with the stored moving averages."""
        trainable = [p for p in parameters if p.requires_grad]
        for shadow, live in zip(self.shadow_params, trainable):
            if live.requires_grad:
                live.data.copy_(shadow.data)

    def store(self, parameters):
        """Snapshot the given parameters so they can be restored later."""
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """Copy the values saved by `store` back into `parameters`.

        Typical use: store(), copy_to() for EMA evaluation/saving, then
        restore() to resume training with the original parameters.
        """
        for saved, live in zip(self.collected_params, parameters):
            live.data.copy_(saved.data)

    def state_dict(self):
        """Return the serializable EMA state (decay, step count, shadows)."""
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict, device):
        """Restore EMA state from `state_dict`, moving shadows to `device`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = [tensor.to(device) for tensor in state_dict['shadow_params']]
| |
|