| import esm |
| import torch |
|
|
| from Bio import SeqIO |
|
|
class ESMFold_Pred:
    """Fold protein chain sequences read from a PDB file with ESMFold.

    The ESMFold v1 model is loaded once at construction, switched to eval
    mode with gradients disabled, and moved to ``device``.
    """

    def __init__(self, device):
        """Load ESMFold v1 and place it on ``device``.

        Args:
            device: Target passed to ``Module.to`` (e.g. ``"cuda"`` or a
                ``torch.device``).
        """
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        # Inference only: freeze parameters so no autograd state is tracked.
        self._folding_model.requires_grad_(False)
        self._folding_model.to(device)

    def predict_str(self, pdbfile, save_path, max_seq_len=1500):
        """Predict the structure of the first suitable chain in ``pdbfile``.

        Sequences are extracted with Biopython's ``"pdb-atom"`` parser. Each
        sequence whose length is at most ``max_seq_len`` is printed; only the
        first such sequence is folded, and the predicted structure (the text
        returned by ``infer_pdb``) is written to ``save_path``, overwriting
        any existing file.

        Args:
            pdbfile: Path or handle accepted by ``Bio.SeqIO.parse``.
            save_path: Output path for the predicted PDB text.
            max_seq_len: Sequences longer than this are skipped (presumably
                to keep folding within memory limits -- TODO confirm).

        Returns:
            The predicted PDB text, or ``None`` if no sequence was short
            enough, in which case nothing is written.
        """
        seq_list = []
        for record in SeqIO.parse(pdbfile, "pdb-atom"):
            seq = str(record.seq)
            if len(seq) > max_seq_len:
                continue
            # Index counts only the sequences that passed the length filter.
            print(f'seq {len(seq_list)}:', seq)
            seq_list.append(seq)

        if not seq_list:
            # Previously this case silently produced no output file.
            print(f'no sequence of length <= {max_seq_len} in {pdbfile}; nothing written')
            return None

        # Only the first kept sequence is folded.
        with torch.no_grad():
            output = self._folding_model.infer_pdb(seq_list[0])
        # "w" rather than "w+": the file is only written, never read back.
        with open(save_path, "w") as f:
            f.write(output)
        return output
|
|
|
|
|
|
|
|