#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
"""Re-check low-frequency somatic SNVs against the BAM pileup.

Usage: python <script> filter_file bam_file output

Every single-base substitution in ``filter_file`` whose ``Freq`` is below
``LOW_FREQ_CUTOFF`` is re-counted directly from ``bam_file``.  Rows without
enough supporting reads are dropped; all remaining rows are written to
``output`` as a tab-separated table.
"""
import sys

import pandas as pd
import pysam

if len(sys.argv) != 4:
    print(" ".join(['usage:python', sys.argv[0], 'filter_file', 'bam_file', 'output']))
    sys.exit()

infile = sys.argv[1]
bamfile = sys.argv[2]
outfile = sys.argv[3]

# Thresholds used by the re-check.
MIN_BASE_QUALITY = 20       # minimum base quality for an observation to count
LOW_FREQ_CUTOFF = 0.05      # only variants with Freq below this are re-checked
MIN_ALT_READS = 3           # minimum number of alt-supporting read names
MIN_SUPPORTED_READS = 1     # minimum alt observations with FR + RR > 1
MIN_ALT_FRACTION = 0.002    # minimum alt fraction (see NOTE at the call site)

samfile = pysam.AlignmentFile(bamfile)


def correct(chr, start, end, alt_base):
    """Count reads supporting ``alt_base`` in the pileup over ``start``..``end``.

    The region is piled up once from the module-level ``samfile`` (mapping
    quality >= 20, no base-quality filter, overlapping mates kept, orphans
    kept, columns truncated to the requested region).  Deletions and
    reference skips are ignored.  Reads are identified by query name, so the
    two mates of a pair count as one read.

    A read name that shows ``alt_base`` is kept as alt support when:

    * it has exactly two observations in the region, both show the same
      base, and at least one has base quality >= ``MIN_BASE_QUALITY``; or
    * it has any other number of observations and the first one has base
      quality >= ``MIN_BASE_QUALITY``.

    Read names with exactly two observations that disagree on the base are
    discarded, and their bases and qualities are printed to stdout.

    An edit-distance (``NM`` < 4) filter existed in earlier revisions and is
    currently disabled; the ``NM`` tag is therefore no longer read.

    Args:
        chr: Reference (contig) name passed to ``samfile.pileup``.
        start: Region start passed to ``samfile.pileup`` (the caller passes
            ``Start - 1``).
        end: Region end passed to ``samfile.pileup``.
        alt_base: Alternate base to look for in the read sequence.

    Returns:
        A tuple ``(alt_reads_num, correct_num, total_depth)``:
        ``alt_reads_num`` is the number of kept alt read names,
        ``correct_num`` is the number of observations from kept reads with
        base quality >= ``MIN_BASE_QUALITY`` and ``FR + RR > 1``, and
        ``total_depth`` is ``alt_reads_num`` plus the number of distinct
        read names showing a non-alt base with quality >=
        ``MIN_BASE_QUALITY``.
    """
    alt_names = set()
    non_alt_names = set()
    # One entry per usable pileup observation, in pileup order:
    # (read name, base, base quality, FR tag, RR tag).
    observations = []

    pileup = samfile.pileup(chr, start, end,
                            stepper="samtools",
                            min_base_quality=0,
                            min_mapping_quality=20,
                            max_depth=100000,
                            ignore_overlaps=False,
                            truncate=True,
                            ignore_orphans=False)
    for pileupcolumn in pileup:
        for pileupread in pileupcolumn.pileups:
            # query_position is not usable for deletions / reference skips.
            if pileupread.is_del or pileupread.is_refskip:
                continue
            alignment = pileupread.alignment
            name = alignment.query_name
            base = alignment.query_sequence[pileupread.query_position]
            quality = alignment.query_qualities[pileupread.query_position]

            if base == alt_base:
                alt_names.add(name)
            elif quality >= MIN_BASE_QUALITY:
                non_alt_names.add(name)

            # Missing tags count as 0.
            # NOTE(review): FR/RR are presumably numeric per-strand support
            # counts written by an upstream consensus tool - confirm.
            fr = alignment.get_tag('FR') if alignment.has_tag('FR') else 0
            rr = alignment.get_tag('RR') if alignment.has_tag('RR') else 0
            observations.append((name, base, quality, fr, rr))

    non_alt_depth = len(non_alt_names)

    # Keep every observation of a read name that showed the alt base at
    # least once (this includes a mate that shows a different base).
    dic = pd.DataFrame(
        [(name, base, quality, fr + rr)
         for name, base, quality, fr, rr in observations
         if name in alt_names],
        columns=['read', 'base', 'quality', 'FR_RR'],
    )

    alt_reads_num = 0
    correct_num = 0
    for _, group in dic.groupby('read', sort=False):
        bases = group['base'].tolist()
        qualities = group['quality'].tolist()

        if len(bases) == 2:
            if bases[0] != bases[1]:
                # The two observations disagree: report and discard.
                print(group['base'])
                print(group['quality'])
                continue
            if qualities[0] < MIN_BASE_QUALITY and qualities[1] < MIN_BASE_QUALITY:
                continue
        elif qualities[0] < MIN_BASE_QUALITY:
            continue

        alt_reads_num += 1
        correct_num += sum(
            1
            for quality, support in zip(qualities, group['FR_RR'].tolist())
            if quality >= MIN_BASE_QUALITY and support > 1
        )

    total_depth = non_alt_depth + alt_reads_num
    return (alt_reads_num, correct_num, total_depth)


try:
    snv = pd.read_table(infile, sep="\t")
except pd.errors.EmptyDataError:
    # An empty input yields an empty table, which is written out unchanged.
    snv = pd.DataFrame()

drop_index = []
for index, row in snv.iterrows():
    # Only single-base substitutions are re-checked.
    if len(row['Ref']) != 1 or len(row['Alt']) != 1:
        continue
    if float(row['Freq']) >= LOW_FREQ_CUTOFF:
        continue

    # Start is shifted by one to give pysam a 0-based region start.
    alt_reads_num, correct_num, total_depth = correct(
        row['Chr'], row['Start'] - 1, row['End'], row['Alt'])

    # NOTE(review): total_depth already includes alt_reads_num, so the
    # denominator below counts the alt reads twice. Left unchanged so the
    # existing filter behaviour is preserved - confirm whether
    # alt_reads_num / total_depth was intended.
    # The fraction is only evaluated when alt_reads_num >= MIN_ALT_READS,
    # so the denominator is never zero.
    if (alt_reads_num < MIN_ALT_READS
            or correct_num < MIN_SUPPORTED_READS
            or alt_reads_num / (alt_reads_num + total_depth) < MIN_ALT_FRACTION):
        drop_index.append(index)

samfile.close()

snv.drop(labels=drop_index, inplace=True)
snv.to_csv(outfile, index=False, sep="\t")