diff --git a/docs/prediction.md b/docs/prediction.md index 6362de0..bc5c434 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -6,6 +6,10 @@ Once you have installed `boltz`, you can start making predictions by simply runn where `` is a path to the input file or a directory. The input file can either be in fasta (enough for most use cases) or YAML format (for more complex inputs). If you specify a directory, `boltz` will run predictions on each `.yaml` or `.fasta` file in the directory. +The `screen` function allows you to predict interactions for multiple ligands, accepting a single `.sdf` file, a directory of `.sdf` files, or a `.smi` file with ligand IDs and SMILES strings separated by spaces or tabs. Proteins can be provided as `.pdb` or `.fasta` files (recommended). The `screen` function supports all arguments from `predict` and additionally allows specifying a precomputed MSA file (`.m3a`) with `--msa_path`. If no MSA is available, the `--use_msa_server` flag can generate it automatically. **Note**: Using `--use_msa_server` sends data to an external server, and confidentiality cannot be guaranteed. + +`boltz screen --protein --ligands ` + Before diving into more details about the input formats, here are the key differences in what they each support: | Feature | Fasta | YAML | @@ -131,14 +135,14 @@ The following options are available for the `predict` command: After running the model, the generated outputs are organized into the output directory following the structure below: ``` out_dir/ -├── lightning_logs/ # Logs generated during training or evaluation -├── predictions/ # Contains the model's predictions - ├── [input_file1]/ - ├── [input_file1]_model_0.cif # The predicted structure in CIF format +��������� lightning_logs/ # Logs generated during training or evaluation +��������� predictions/ # Contains the model's predictions + ��������� [input_file1]/ + ��������� [input_file1]_model_0.cif # The predicted structure in CIF format ... - └── [input_file1]_model_[diffusion_samples-1].cif # The predicted structure in CIF format - └── [input_file2]/ + ��������� [input_file1]_model_[diffusion_samples-1].cif # The predicted structure in CIF format + ��������� [input_file2]/ ... -└── processed/ # Processed data used during execution +��������� processed/ # Processed data used during execution ``` The `predictions` folder contains a unique folder for each input file. The input folders contain diffusion_samples predictions saved in the output_format. The `processed` folder contains the processed input files that are used by the model during inference. diff --git a/src/boltz/main.py b/src/boltz/main.py index 43bc993..b2f8703 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -1,4 +1,7 @@ +import glob import pickle +import shutil +import string import urllib.request from dataclasses import asdict, dataclass from pathlib import Path @@ -6,9 +9,13 @@ import click import torch +from Bio import SeqIO +from Bio.PDB import PDBParser +from Bio.PDB.Polypeptide import PPBuilder from pytorch_lightning import Trainer from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.utilities import rank_zero_only +from rdkit import Chem from tqdm import tqdm from boltz.data import const @@ -158,7 +165,9 @@ def check_inputs( return data -def compute_msa(data: dict[str, str], msa_dir: Path, msa_server_url:str, msa_pairing_strategy:str) -> list[Path]: +def compute_msa( + data: dict[str, str], msa_dir: Path, msa_server_url: str, msa_pairing_strategy: str +) -> list[Path]: """Compute the MSA for the input data. Parameters @@ -175,7 +184,13 @@ def compute_msa(data: dict[str, str], msa_dir: Path, msa_server_url:str, msa_pai """ # Run MMSeqs2 - msa = run_mmseqs2(list(data.values()), msa_dir, use_pairing=len(data) > 1, host_url=msa_server_url, pairing_strategy=msa_pairing_strategy) + msa = run_mmseqs2( + list(data.values()), + msa_dir, + use_pairing=len(data) > 1, + host_url=msa_server_url, + pairing_strategy=msa_pairing_strategy, + ) # Dump to A3M for idx, key in enumerate(data): @@ -278,7 +293,12 @@ def process_inputs( # noqa: C901, PLR0912, PLR0915 if to_generate: msg = f"Generating MSA for {path} with {len(to_generate)} protein entities." click.echo(msg) - compute_msa(to_generate, msa_dir, msa_server_url=msa_server_url, msa_pairing_strategy=msa_pairing_strategy) + compute_msa( + to_generate, + msa_dir, + msa_server_url=msa_server_url, + msa_pairing_strategy=msa_pairing_strategy, + ) # Parse MSA data msas = {c.msa_id for c in target.record.chains if c.msa_id != -1} @@ -324,91 +344,95 @@ def cli() -> None: return -@cli.command() -@click.argument("data", type=click.Path(exists=True)) -@click.option( - "--out_dir", - type=click.Path(exists=False), - help="The path where to save the predictions.", - default="./", -) -@click.option( - "--cache", - type=click.Path(exists=False), - help="The directory where to download the data and model. Default is ~/.boltz.", - default="~/.boltz", -) -@click.option( - "--checkpoint", - type=click.Path(exists=True), - help="An optional checkpoint, will use the provided Boltz-1 model by default.", - default=None, -) -@click.option( - "--devices", - type=int, - help="The number of devices to use for prediction. Default is 1.", - default=1, -) -@click.option( - "--accelerator", - type=click.Choice(["gpu", "cpu", "tpu"]), - help="The accelerator to use for prediction. Default is gpu.", - default="gpu", -) -@click.option( - "--recycling_steps", - type=int, - help="The number of recycling steps to use for prediction. Default is 3.", - default=3, -) -@click.option( - "--sampling_steps", - type=int, - help="The number of sampling steps to use for prediction. Default is 200.", - default=200, -) -@click.option( - "--diffusion_samples", - type=int, - help="The number of diffusion samples to use for prediction. Default is 1.", - default=1, -) -@click.option( - "--output_format", - type=click.Choice(["pdb", "mmcif"]), - help="The output format to use for the predictions. Default is mmcif.", - default="mmcif", -) -@click.option( - "--num_workers", - type=int, - help="The number of dataloader workers to use for prediction. Default is 2.", - default=2, -) -@click.option( - "--override", - is_flag=True, - help="Whether to override existing found predictions. Default is False.", -) -@click.option( - "--use_msa_server", - is_flag=True, - help="Whether to use the MMSeqs2 server for MSA generation. Default is False.", -) -@click.option( - "--msa_server_url", - type=str, - help="MSA server url. Used only if --use_msa_server is set. ", - default="https://api.colabfold.com", -) -@click.option( - "--msa_pairing_strategy", - type=str, - help="Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete'", - default="greedy", -) -def predict( +# Define a reusable decorator for screen and predict +def shared_options(func): + func = click.option( + "--out_dir", + type=click.Path(exists=False), + default="./", + help="The path where to save the predictions.", + )(func) + func = click.option( + "--cache", + type=click.Path(exists=False), + default="~/.boltz", + help="The directory where to download the data and model. Default is ~/.boltz.", + )(func) + func = click.option( + "--checkpoint", + type=click.Path(exists=True), + default=None, + help="An optional checkpoint, will use the provided Boltz-1 model by default.", + )(func) + func = click.option( + "--devices", + type=int, + default=1, + help="The number of devices to use for prediction. Default is 1.", + )(func) + func = click.option( + "--accelerator", + type=click.Choice(["gpu", "cpu", "tpu"]), + default="gpu", + help="The accelerator to use for prediction. Default is gpu.", + )(func) + func = click.option( + "--recycling_steps", + type=int, + default=3, + help="The number of recycling steps to use for prediction. Default is 3.", + )(func) + func = click.option( + "--sampling_steps", + type=int, + default=200, + help="The number of sampling steps to use for prediction. Default is 200.", + )(func) + func = click.option( + "--diffusion_samples", + type=int, + default=1, + help="The number of diffusion samples to use for prediction. Default is 1.", + )(func) + func = click.option( + "--output_format", + type=click.Choice(["pdb", "mmcif"]), + default="mmcif", + help="The output format to use for the predictions. Default is mmcif.", + )(func) + func = click.option( + "--num_workers", + type=int, + default=2, + help="The number of dataloader workers to use for prediction. Default is 2.", + )(func) + func = click.option( + "--override", + is_flag=True, + help="Whether to override existing found predictions. Default is False.", + )(func) + func = click.option( + "--use_msa_server", + is_flag=True, + help="Whether to use the MMSeqs2 server for MSA generation. Default is False.", + )(func) + func = click.option( + "--msa_server_url", + type=str, + default="https://api.colabfold.com", + help="MSA server url. Used only if --use_msa_server is set.", + )(func) + func = click.option( + "--msa_pairing_strategy", + type=str, + default="greedy", + help="Pairing strategy to use. Options are 'greedy' and 'complete'.", + )(func) + return func + + +# Predict workflow +def predict_input( data: str, out_dir: str, cache: str = "~/.boltz", @@ -533,5 +557,246 @@ def predict( ) +@cli.command() +@click.argument("data", type=click.Path(exists=True)) +@shared_options +def predict( + data: str, + out_dir: str, + cache: str = "~/.boltz", + checkpoint: Optional[str] = None, + devices: int = 1, + accelerator: str = "gpu", + recycling_steps: int = 3, + sampling_steps: int = 200, + diffusion_samples: int = 1, + output_format: Literal["pdb", "mmcif"] = "mmcif", + num_workers: int = 2, + override: bool = False, + use_msa_server: bool = False, + msa_server_url: str = "https://api.colabfold.com", + msa_pairing_strategy: str = "greedy", +) -> None: + predict_input( + data, + out_dir=out_dir, + cache=cache, + checkpoint=checkpoint, + devices=devices, + accelerator=accelerator, + recycling_steps=recycling_steps, + sampling_steps=sampling_steps, + diffusion_samples=diffusion_samples, + output_format=output_format, + num_workers=num_workers, + override=override, + use_msa_server=use_msa_server, + msa_server_url=msa_server_url, + msa_pairing_strategy=msa_pairing_strategy, + ) + + +# small function to parse a complete SDF to smiles +def _process_sdf(sdf_path: str): + output_dict = {} + suppl = Chem.SDMolSupplier(sdf_path) + + for mol in suppl: + if mol is not None: + mol_smiles = Chem.MolToSmiles(mol) + if mol.HasProp("_Name"): + mol_name = mol.GetProp("_Name") + if mol_name == "": + mol_name = mol_smiles + else: + mol_name = mol_smiles + + output_dict[mol_name] = mol_smiles + + return output_dict + + +@cli.command() +@click.option( + "--protein", + type=click.Path(exists=True), + required=True, + help="The path to the PDB or fasta file", +) +@click.option( + "--ligands", + type=click.Path(exists=True), + required=True, + help=( + "Path to the compounds to screen against your protein. This can be either: " + "a directory containing multiple SDF files, a single SDF file with multiple structures, " + "or a text file with compound IDs and their corresponding SMILES strings (in that order)." + ), +) +@click.option( + "--msa_path", + type=click.Path(exists=False), + help="The path to precomputed MSA (should be in m3a format)", + default="", +) +@shared_options +def screen( + protein: str, + ligands: str, + msa_path: str, + out_dir: str, + cache: str = "~/.boltz", + checkpoint: Optional[str] = None, + devices: int = 1, + accelerator: str = "gpu", + recycling_steps: int = 3, + sampling_steps: int = 200, + diffusion_samples: int = 1, + output_format: Literal["pdb", "mmcif"] = "mmcif", + num_workers: int = 2, + override: bool = False, + use_msa_server: bool = False, + msa_server_url: str = "https://api.colabfold.com", + msa_pairing_strategy: str = "greedy", +) -> None: + """Screen many ligands against 1 protein target with Boltz-1.""" + protein_path = Path(protein).expanduser() + ligand_path = Path(ligands).expanduser() + + # Process the protein input + protein_name = protein_path.stem + + if protein_path.suffix.lower() == ".pdb": + # Get FASTA sequence from pdb file using biopython + parser = PDBParser(QUIET=True) + structure = parser.get_structure("protein", protein_path) + ppb = PPBuilder() + + protein_seq = { + chain.id: str(pp.get_sequence()) + for model in structure + for chain in model + for pp in ppb.build_peptides(chain) + } + + elif protein_path.suffix.lower() in (".fa", ".fas", ".fasta"): + # Process as fasta file + with protein_path.open("r") as f: + protein_seq = { + string.ascii_uppercase[i]: str(record.seq) + for i, record in enumerate(SeqIO.parse(protein_path, "fasta")) + } + + else: + msg = f"File format {path.suffix} not supported, please provide file in pdb or fasta format" + raise click.ClickException(msg) + + # Get a list of all the ligand smiles + smiles_dict = {} + + if ligand_path.is_file(): + ligand_name = ligand_path.stem + + # check the extension + if ligand_path.suffix.lower() == ".sdf": + smiles_dict.update(_process_sdf(ligand_path)) + elif ligand_path.suffix.lower() in [".smi", ".ism", ".smiles"]: + # split and add to dict + with open(ligand_path) as ligand_file: + ligand_lines = ligand_file.readlines() + for line in ligand_lines: + line_split = line.strip().split() + smiles_dict[line_split[0]] = line_split[1] + else: + msg = f"Files with {ligand_path.suffix} extension are not supported as ligand. Only .sdf and .smi files are supported" + raise click.ClickException(msg) + + else: + ligand_name = ligand_path.name + + # Get all the sdf files and add them to the dictionary + ligand_files = glob.glob(f"{ligand_path}/*.sdf") + + for ligand_file in ligand_files: + smiles_dict.update(_process_sdf(ligand_file)) + + msg = f"Succesfully identified {len(smiles_dict)} ligands." + click.echo(msg) + + # Generate output directory + out_dir = Path(out_dir).expanduser() + out_dir = out_dir / f"boltz_results_{protein_name}_{ligand_name}" + + # Check if the output directory already exists + if out_dir.exists(): + click.echo(f"The output directory '{out_dir}' already exists.") + if not override: + click.confirm( + "Do you want to delete the existing directory and continue?", + abort=True, + ) + # Delete the directory if confirmed + shutil.rmtree(out_dir) + click.echo() + else: + click.echo("Override flag is set. The existing directory will be used.") + + # Create the output directory + out_dir.mkdir(parents=True, exist_ok=True) + + msa_dir = out_dir / "msa" + msa_dir.mkdir(parents=True, exist_ok=True) + + # Perform the alignment MSA alignment using the FASTA sequence (if use_msa_server is given and no path is given) + if msa_path == "" and use_msa_server == True: + msg = f"Generating MSA for {protein_name}." + click.echo(msg) + compute_msa( + protein_seq, + msa_dir, + msa_server_url=msa_server_url, + msa_pairing_strategy=msa_pairing_strategy, + ) + else: + msg = "No MSA path given, and use_msa_server is set to false. Please generate the MSA manually and include it as an argument, or consider adding --use_msa_server if the protein sequence is not confidential." + raise click.ClickException(msg) + + # Make a directory where we can write all the query files to + query_dir = out_dir / "queries" + query_dir.mkdir(parents=True, exist_ok=True) + + # Make the query template + query_template = "" + + for chain in protein_seq: + query_template += ( + f">{chain}|protein|{msa_dir}/{chain}.a3m\n{protein_seq[chain]}\n" + ) + query_template += f">{string.ascii_uppercase[len(protein_seq)]}|smiles\n" + + for query_id in smiles_dict: + with open(query_dir / f"{query_id}_{protein_name}.fasta", "w") as query_file: + query_file.write(f"{query_template}{smiles_dict[query_id]}\n") + + # Call the predict_input function, where we pass all variables + predict_input( + query_dir, + out_dir=out_dir, + cache=cache, + checkpoint=checkpoint, + devices=devices, + accelerator=accelerator, + recycling_steps=recycling_steps, + sampling_steps=sampling_steps, + diffusion_samples=diffusion_samples, + output_format=output_format, + num_workers=num_workers, + override=override, + use_msa_server=use_msa_server, + msa_server_url=msa_server_url, + msa_pairing_strategy=msa_pairing_strategy, + ) + + if __name__ == "__main__": cli()