| import sys |
| import os |
| import json |
| import copy |
| import numpy as np |
| import scipy.stats as stats |
| import math |
| from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment |
|
|
def cal_metrics(y_score, y_true):
    """Compute virtual-screening ranking metrics for one target.

    The scores and labels are stacked into a two-column array
    (column 0 = score, column 1 = label), sorted by descending score, and
    handed to RDKit's scoring helpers with the label column index ``1``.

    Args:
        y_score: Predicted scores, one per ligand (higher ranks first).
            Assumed to be a 1-D array-like -- TODO confirm with callers.
        y_true: Activity labels aligned with ``y_score``.
            Presumably 1 marks an active and 0 a decoy -- verify upstream.

    Returns:
        dict with the keys ``"BEDROC"`` (alpha = 80.5), ``"AUROC"``,
        ``"EF0.5"``, ``"EF1"`` and ``"EF5"`` (enrichment factors at the
        top 0.5 %, 1 % and 5 % of the ranked list).
    """
    ranked = np.column_stack((y_score, y_true))
    # RDKit expects the rows to be pre-sorted, best score first.
    ranked = ranked[ranked[:, 0].argsort()[::-1]]

    bedroc = CalcBEDROC(ranked, 1, 80.5)
    auc = CalcAUC(ranked, 1)
    # The 2 % enrichment factor is still requested so the RDKit call is
    # unchanged, but it is not part of the reported result.
    ef_list = CalcEnrichment(ranked, 1, [0.005, 0.01, 0.02, 0.05])

    return {
        "BEDROC": bedroc,
        "AUROC": auc,
        "EF0.5": ef_list[0],
        "EF1": ef_list[1],
        "EF5": ef_list[3],
    }
|
|
def print_avg_metric(metric_dict, name):
    """Print the per-field mean of the metrics stored in ``metric_dict``.

    Only numeric fields are averaged. Non-numeric fields (for example the
    string ``"target"`` label that ``get_metric`` attaches to each entry)
    are left out, because they cannot be summed and divided.

    Args:
        metric_dict: Mapping of an identifier to a dict of metric values.
            The fields of the first entry decide what gets averaged.
        name: Label printed in front of the averaged metrics.

    Returns:
        None. The result is written to stdout as ``name {field: mean, ...}``;
        an empty ``metric_dict`` prints ``name {}``.
    """
    metric_lst = list(metric_dict.values())
    if not metric_lst:
        print(name, {})
        return

    first, rest = metric_lst[0], metric_lst[1:]
    ret_metric = {}
    for k, v in first.items():
        if not isinstance(v, (int, float, np.number)):
            continue
        # Start the sum from the first entry so the addition order (and
        # therefore the floating-point result) is first + second + ...
        total = sum((m[k] for m in rest if k in m), v)
        ret_metric[k] = total / len(metric_lst)
    print(name, ret_metric)
|
|
def read_zeroshot_res(res_dir):
    """Load zero-shot predictions and labels for every entry of ``res_dir``.

    Each entry of ``res_dir`` is treated as a target folder. Labels come from
    ``saved_labels.npy``. Predictions come from ``saved_preds.npy`` when that
    file exists; otherwise they are rebuilt from the saved embeddings as the
    column-wise maximum of ``target_embed @ mols_embed.T``.

    Args:
        res_dir: Directory holding one sub-folder per target.

    Returns:
        dict mapping each target name (in sorted order) to
        ``{"pred": <predictions>, "exp": <labels>}``.
    """
    res_dict = {}
    for target in sorted(os.listdir(res_dir)):
        target_dir = f"{res_dir}/{target}"
        labels = np.load(f"{target_dir}/saved_labels.npy")

        preds_path = f"{target_dir}/saved_preds.npy"
        if not os.path.exists(preds_path):
            # No cached predictions: score every molecule against every
            # saved target embedding and keep the best score per molecule.
            # (assumes both arrays are 2-D, one embedding per row -- TODO confirm)
            mol_embed = np.load(f"{target_dir}/saved_mols_embed.npy")
            target_embed = np.load(f"{target_dir}/saved_target_embed.npy")
            preds = np.max(target_embed @ mol_embed.T, axis=0)
        else:
            preds = np.load(preds_path)

        res_dict[target] = {"pred": preds, "exp": labels}
    return res_dict
|
|
def get_ensemble_res(res_list, begin=0, end=-1):
    """Average the ``"pred"`` arrays of ``res_list[begin:end]`` per target.

    Args:
        res_list: Sequence of result dicts. Each maps a target key to a dict
            that has at least a ``"pred"`` entry (array-like).
        begin: Index of the first result to include.
        end: Index one past the last result to include; ``-1`` (the default)
            means "up to the end of ``res_list``".

    Returns:
        A deep copy of the first selected result in which every
        ``[target]["pred"]`` is replaced by the element-wise mean (a NumPy
        array) over all selected results. All other fields are those of the
        first selected result. ``res_list`` itself is not modified.

    Raises:
        IndexError: If the requested range selects no results.
    """
    if end == -1:
        end = len(res_list)

    selected = res_list[begin:end]
    if not selected:
        raise IndexError(
            f"no results to ensemble in range [{begin}:{end}] "
            f"of a list with {len(res_list)} entries"
        )

    ret = copy.deepcopy(selected[0])
    for k in ret:
        total = np.array(ret[k]["pred"])
        for res in selected[1:]:
            total = total + np.array(res[k]["pred"])
        # Divide by the number of results actually summed, not by the
        # nominal ``end - begin``, which is too large when the range runs
        # past the end of ``res_list``.
        ret[k]["pred"] = total / len(selected)
    return ret
|
|
def avg_metric(metric_lst_all):
    """Average correlation metrics within each group of metric dicts.

    Args:
        metric_lst_all: Iterable of non-empty sequences of metric dicts. Every
            dict carries ``"pearsonr"``, ``"spearmanr"``, ``"r2"`` and
            ``"target"``.

    Returns:
        dict keyed by the ``"target"`` of each group's first entry. Each value
        is a deep copy of that first entry whose three correlation fields are
        replaced by their mean over the group.
    """
    averaged_fields = ("pearsonr", "spearmanr", "r2")
    averaged = {}
    for group in metric_lst_all:
        summary = copy.deepcopy(group[0])
        remainder = group[1:]
        group_size = len(group)
        for field in averaged_fields:
            # Seed the sum with the first entry's value so the additions
            # happen in list order.
            total = sum((entry[field] for entry in remainder), summary[field])
            summary[field] = total / group_size
        averaged[summary["target"]] = summary
    return averaged
|
|
def get_metric(res):
    """Compute correlation metrics between predictions and labels per target.

    Args:
        res: Mapping of target key to ``{"pred": ..., "exp": ...}``.

    Returns:
        dict, filled in sorted key order, mapping each target to a dict with:
        ``"pearsonr"`` and ``"spearmanr"`` (NaN results are replaced by 0),
        ``"r2"`` (the square of the Pearson value clipped below at 0), and
        ``"target"`` (the key itself).
    """

    def _zero_if_nan(value):
        # SciPy yields NaN when a correlation is undefined; report 0 instead.
        return 0 if math.isnan(value) else value

    metric_dict = {}
    for target in sorted(res):
        entry = res[target]
        predicted, measured = entry["pred"], entry["exp"]

        rank_corr = _zero_if_nan(stats.spearmanr(measured, predicted).statistic)
        linear_corr = _zero_if_nan(stats.pearsonr(measured, predicted).statistic)

        metric_dict[target] = {
            "pearsonr": linear_corr,
            "spearmanr": rank_corr,
            # Negative correlations contribute an r2 of 0.
            "r2": max(linear_corr, 0) ** 2,
            "target": target,
        }
    return metric_dict
|
|
|
|
def _dump_json(obj, path):
    """Serialize ``obj`` as JSON into ``path`` and close the file."""
    with open(path, "w") as fh:
        json.dump(obj, fh)


def _load_jsonl_results(path):
    """Parse a JSON-lines file and return every record except the first line."""
    with open(path) as fh:
        return [json.loads(line) for line in fh][1:]


def _without_target(metrics):
    """Return per-target metrics with the string ``"target"`` label removed."""
    return {
        target: {k: v for k, v in fields.items() if k != "target"}
        for target, fields in metrics.items()
    }


def _run_zeroshot(test_sets):
    """Fuse pocket- and protein-ranking zero-shot results and report metrics."""
    for test_set in test_sets:
        pocket_dir = f"./result/pocket_ranking/{test_set}"
        protein_dir = f"./result/protein_ranking/{test_set}"
        metrics_path = f"./result/pocket_ranking/{test_set}_metrics.json"

        if test_set in ["DUDE", "PCBA", "DEKOIS"]:
            metrics = {}
            for target_id in sorted(os.listdir(pocket_dir)):
                lig_act = np.load(f"{pocket_dir}/{target_id}/saved_labels.npy")
                score_1 = np.load(f"{pocket_dir}/{target_id}/GNN_res_epoch9.npy")
                score_2 = np.load(f"{protein_dir}/{target_id}/GNN_res_epoch9.npy")
                # Fusion = plain mean of the two models' scores.
                score = (score_1 + score_2) / 2
                metrics[target_id] = cal_metrics(score, lig_act)
            _dump_json(metrics, metrics_path)
            print_avg_metric(metrics, "Ours")
        elif test_set in ["FEP"]:
            res_all_pocket, res_all_protein = [], []
            for repeat in range(1, 6):
                res_all_pocket.append(
                    read_zeroshot_res(f"{pocket_dir}/repeat_{repeat}"))
                res_all_protein.append(
                    read_zeroshot_res(f"{protein_dir}/repeat_{repeat}"))
            res_all_fusion = get_ensemble_res(res_all_pocket + res_all_protein)
            metrics = get_metric(res_all_fusion)
            _dump_json(metrics, metrics_path)
            # Only the numeric fields are meaningful to average; drop the
            # "target" label before summarising.
            print_avg_metric(_without_target(metrics), "Ours")
        else:
            print(f"Skipping unknown zero-shot test set: {test_set}",
                  file=sys.stderr)


def _run_fewshot(test_set, support_num, begin=15, end=20):
    """Fuse few-shot results over seeds 1-10 and report averaged metrics.

    ``begin``/``end`` select which records of each FEP_fewshot result file are
    ensembled; they are not applied in the TIME/OOD branch, where every record
    after the first line is used.
    """
    metric_fusion_all = []
    for seed in range(1, 11):
        res_repeat_pocket = []
        res_repeat_seq = []
        file_name = f"random_{seed}_sup{support_num}.jsonl"

        if test_set in ["TIME", "OOD"]:
            res_file_pocket = f"./result/pocket_ranking/{test_set}/{file_name}"
            # The sequence-model results are read from protein_ranking, like
            # the zero-shot branches; reading the pocket file twice would make
            # the fusion a no-op.
            # NOTE(review): confirm this directory layout for few-shot runs.
            res_file_seq = f"./result/protein_ranking/{test_set}/{file_name}"
            if not (os.path.exists(res_file_pocket)
                    and os.path.exists(res_file_seq)):
                continue
            res_repeat_pocket = _load_jsonl_results(res_file_pocket)
            res_repeat_seq = _load_jsonl_results(res_file_seq)
        elif test_set in ["FEP_fewshot"]:
            for repeat in range(1, 6):
                res_file_pocket = (
                    f"./result/pocket_ranking/{test_set}/repeat_{repeat}/{file_name}")
                res_file_seq = (
                    f"./result/protein_ranking/{test_set}/repeat_{repeat}/{file_name}")
                if not (os.path.exists(res_file_pocket)
                        and os.path.exists(res_file_seq)):
                    continue
                res_pocket = _load_jsonl_results(res_file_pocket)
                res_seq = _load_jsonl_results(res_file_seq)
                res_repeat_pocket.append(get_ensemble_res(res_pocket, begin, end))
                res_repeat_seq.append(get_ensemble_res(res_seq, begin, end))

        res_all = res_repeat_pocket + res_repeat_seq
        if not res_all:
            # Nothing was found for this seed (missing files or an
            # unsupported test set); there is nothing to ensemble.
            continue
        metric_fusion_all.append(get_metric(get_ensemble_res(res_all)))

    if not metric_fusion_all:
        sys.exit(f"No few-shot results found for test set {test_set!r} "
                 f"(support_num={support_num})")

    # get_metric returns one dict per seed keyed by target; regroup so that
    # avg_metric receives, for each target, the list of its per-seed metrics.
    per_target = {}
    for seed_metrics in metric_fusion_all:
        for target, fields in seed_metrics.items():
            per_target.setdefault(target, []).append(fields)
    averaged = avg_metric(list(per_target.values()))

    _dump_json(averaged, f"./result/pocket_ranking/{test_set}_metrics.json")
    # Only the numeric fields are meaningful to average; drop the "target"
    # label before summarising.
    print_avg_metric(_without_target(averaged), "Ours")


if __name__ == '__main__':
    usage = (f"usage: {sys.argv[0]} zeroshot TEST_SET [TEST_SET ...] | "
             f"{sys.argv[0]} fewshot TEST_SET SUPPORT_NUM")
    if len(sys.argv) < 2:
        sys.exit(usage)
    mode = sys.argv[1]
    if mode == "zeroshot":
        _run_zeroshot(sys.argv[2:])
    elif mode == "fewshot":
        if len(sys.argv) < 4:
            sys.exit(usage)
        _run_fewshot(sys.argv[2], sys.argv[3])
    else:
        sys.exit(usage)
|
|