#!/usr/bin/python

### blastflt.py (c) 1999-2003 Arne Mueller

###############################################################################
###
### NAME       : blastflt.py
### TYPE       : program
### PYTHON     : 1.5.2
### VERSION    : 1.1.3
### AUTHOR     : Arne Mueller
### DATE       : Wed Apr 21 BST 1999
### DESCRIPTION:
### The program reads in a blast or psi-blast file, and extracts all non
### redundant hits from all iterations. That means a hit will be treated only
### once; if a hit occurs more than once (in different iterations) the one with
### the best e-value is taken. Alignments of all HSPs of these hits are listed
### with respect to a reference sequence (the query) in a pseudo multiple sequence
### alignment. Each line of the output contains a serial number of the alignment,
### 'database_name:hit_name', e-value, score and %id of alignment, start/end of
### alignment with respect to query sequence and the alignment which is exactly
### the same length as the query including leading and trailing gaps ('-'). Hits
### are sorted by e-value (from low to high). The first line always contains the
### query and legend to explain the meaning of the different columns.
###
### Different commandline options are available, e.g. for reducing the number
### of hit sequences in the data structure.
###
### USAGE:
###
### blastflt.py [--fasta query_sequence_file.fasta] [--blast file.blast]
###             [--out out_file]
### either query_sequence_file.fasta or file.blast can be replaced by stdin
###
### Type blastflt.py --help
###
### NOTES:
### Only (psi)blast output with pairwise alignments (-m 0) is recognized.
### The number of sequences in the summary block has to be equal to the number
### of alignments to show (e.g. -b 500 -v 500)
###
### TODO:
### Adjust class 'MyIteration' so that it can handle iterations where the
### number of hits in the summary block is different from the number of hits
### in the alignment block (e.g. -b 1000 -v 0)
###
### CHANGES:
###
### CHANGES VERSION 1.1
### DATE: 07.11.99
### o changed commandline options from short to long format
### o added commandline options: pid, evalue, cluster, minpid, minoverlap,
###   maxoffset, keep, minhsp, dbsep, help - see help text or source code for
###   details.
### o added function 'Thresholds' to delete data outside given thresholds
###   'evalue' or 'pid' from blast structure
### o function 'getHSPlist'. See function's docstring for details. Currently
###   not used (only for development purpose)
### o function 'Cluster' to cluster redundant HSPs and to get the best
###   representative for each cluster. This function reduces the number of Hits
###   and HSPs in blast data structure (see function's docstring)
###
### CHANGES VERSION 1.1.1
### DATE: 14.12.99
### o Removing redundant hits from iterations (in class MyIteration) conflicted
###   with driftfilter (when activated) because redundant hits removed from an
###   iteration will be recognized as a drift. Fixed that bug by setting
###   redundant hits to None and removing these after parsing is complete.
###
### CHANGES VERSION 1.1.2
### DATE: 15.02.2000
### o support for TBLASTN in blast module. If the blast object is a TBLASTN
###   object, the translation frame is reported for each HSP
###
### DATE: 03.03.2000
### o small bug fix on option --keep, changed the default regular expression
###   from '\|(pdb)|(scop)\|' to '\|(pdb|scop)\|' (which is the 'more' correct
###   form).
###
### CHANGES VERSION 1.1.3
### DATE: 29.03.2000
###
### o changed defaults for options --blast_file and --query_file to sys.stdin
###   (real file streams instead of names). Note you can still provide strings
###   to these options, these are then assumed to file names and opened in FIO.
###
### DATE: 01.05.2000
###
### o fixed a bug in 'getQuery()' when reading fasta files that don't terminate
###   with '\n'. This change is necessary due to a change in class fio.FIO
###   ('getLine()' returns None instead of '' when trying to fill the read
###   buffer when EOF is reached!).
###
### DATE: 29.01.2000
###
### o new switch 'nostacked' to turn off stacking of pair alignments
###   to see all changes look for the string [SH290101]
### DATE: 17.08.2001
###
### o Merged in changes from Stephan Herschel (sh) from 29.01.2000. New switch
###   'nostacked' to turn off stacking of pair alignments. This is useful if
###   you need to get the real alignments between query and subject; gaps in
###   the query are shown.
###
### CHANGES VERSION 1.1.4
###
### DATE 17.07.2003
### 
### o added option '--descrip' to append description line to output
###############################################################################

import sys
import string
import types
import getopt
import time
from fio import FIO
from re import *
from blast import *
from misc import AlignmentContainer

###
### Classes (new and redefined)
###

###############################################################################

class MyHit(Hit):

    """Blast hit with a reduced alignment parser.

       TYPE       : class
       CONSTRUCTOR: No (the one of base class 'Hit' is used)
       DESCRIPTION:
       Overrides 'parseAlignment' of the base class 'Hit'. According to the
       original author this specialisation skips the full text description
       of each hit (NOTE(review): confirm against module 'blast').
    """

    def parseAlignment(self, token):
        "Adopt the id matched by 'token' if none is set yet, then parse."
        ### the id is only taken from the token's match object when the hit
        ### does not carry one already
        already_named = self.id
        if not already_named:
            self.id = token.m.group(1)
        ### deriving self.db/self.name via self.getDbNamefromId() is left
        ### disabled here
        self.parse()

###############################################################################

class MyIteration(Iteration):

    """Blast iteration that keeps only the best occurrence of each hit

       TYPE       : class
       CONSTRUCTOR: takes the blast object as argument and calls the
                    base class constructor
       DESCRIPTION:
       After the alignments of an iteration have been parsed, every hit
       is compared with the module level dictionary 'hitlist' (hit key ->
       hit). A hit that is already known with a smaller e-value is
       discarded, otherwise the new hit replaces the known one.

       NOTE: This class works only when number of hits in summary block
       == number of hits in alignment block
    """
    def __init__(self, blastobj):
        ### 'new' is set before the base class constructor runs
        self.new = 1
        Iteration.__init__(self, blastobj)

    def parseAlignments(self, token):
        "Parse as the base class does, then merge the hits into 'hitlist'."
        Iteration.parseAlignments(self, token)
        self.reduceHits()

    def reduceHits(self):
        "Merge the hits of this iteration into the global 'hitlist'."
        ### NOTE: redundant hits are never deleted here because the
        ### drift-filter of the blast object would report a drift.
        ### They are overwritten with None instead and have to be removed
        ### once parsing is complete (outside this class!)
        for key in self.hits.keys():
            ### back reference from the hit to the iteration owning it
            self.hits[key].itref = self
            if hitlist.has_key(key):
                if hitlist[key].e < self.hits[key].e:
                    ### the known hit has the better e-value: drop ours
                    self.hits[key] = None
                    continue
                ### our hit is at least as good: blank the known one in
                ### the iteration it came from
                hitlist[key].itref.hits[key] = None
            hitlist[key] = self.hits[key]

###############################################################################

class StreamCP:

    """Copies everything written to it to both io-streams 'src' and
       'dst'; can be used to copy stderr to a log file (like unix tee).

       src: first stream written to (default: sys.stderr as bound when
            this class was defined)
       dst: second stream written to (default: sys.stdout as bound when
            this class was defined)

       Only the output side of the file protocol is provided:
       'write' and 'flush'.
    """

    def __init__(self, src=sys.stderr, dst=sys.stdout):
        self.src = src
        self.dst = dst

    def write(self, msg=''):
        "Write 'msg' to 'src' first and then to 'dst'."
        self.src.write(msg)
        self.dst.write(msg)

    def flush(self):
        """Flush both underlying streams.

           An instance replaces sys.stderr in this program; without this
           method any caller doing sys.stderr.flush() would fail with an
           AttributeError. Streams that offer no 'flush' are skipped.
        """
        for stream in (self.src, self.dst):
            if hasattr(stream, 'flush'):
                stream.flush()

###############################################################################

###
### Unbound functions
###

###
def printAll(blast):
    "Debugging routine"
    for i in blast.iter:
        print "********************"
        print "ITERATION : %s\n" % i.it,
        print "********************"
        for k in i.hits.keys():
            print '\nHIT: %s, length =  %d, status = %s' % (
                i.hits[k].id, i.hits[k].length, i.hits[k].status)
            print 'DESCRIPTION: %s' % i.hits[k].descr
            for hsp in  i.hits[k].hsp:
                print 'HSP: score = %f, e = %f, identity = %d, q_start = %d,\
                q_end = %d,  s_start = %d, s_end = %d'% (
                    hsp.score, hsp.e, hsp.identity,
                    hsp.q_start, hsp.q_end, hsp.s_start, hsp.s_end)
                print 'QUERY: %s' % hsp.query
                print 'SBJCT: %s' % hsp. subject

### operates on a directly 'unsorted' blast object
def stackHits(blast, qseq, qname):
    "Alternative output routine for blastflt.py"    
    m = 1
    seq = ''
    str = '    NR  %-22s E            SCORE   ID  STRT  END  %s' % (qname, qseq)
    print str
    for i in blast.iter:
        for k in i.hits.keys():
            n = 1
            for hsp in i.hits[k].hsp:
                end = hsp.end
                seq = align(qseq, hsp.query, hsp.subject, hsp.q_start, hsp.q_end)
                str = '%4d.%-2d' % (m, n)
                n = n + 1
                str = str + ' %-22s' % i.hits[k].id
                str = str + ' %-#12.3e %-7.2f %-3d' % (hsp.e, hsp.score, hsp.identity)
                str = str + ' %-4d %-4d' % (hsp.start, hsp.end)
                str = str + ' %s' % seq
                print str
            m = m + 1

###############################################################################
def nostackListedHits(list, qname):
    # added by sh 29.01.2001
    "Alternative output routine for blastflt.py"
    m = 1
    seq = ''
    klist = list.keys()
    klist.sort(sortByE)
    for k in klist:
        n = 1
        for hsp in list[k].hsp:
            if b.type == 'TBLASTN':
                str1 = 'FRAME ';
            else:
                str1 = ''
            str1 = '    NR  %-38s %sE            SCORE   ID  QSTRT QEND  SSTRT SEND   %s' % (
                  qname, str1, hsp.query)
            print str1
            end = hsp.end
            seq = hsp.subject
            str = '%4d.%-2d' % (m, n)
            n = n + 1
            str = str + ' %-38s' % list[k].id
            if hsp.__dict__.has_key('frame'):
                str = str + ' %-5s' % hsp.frame
            str = str + ' %-#12.3e %-7.2f %-3d' % (hsp.e, hsp.score, hsp.identity)
            str = str + ' %-5d %-5d %-5d %-5d ' % (
                hsp.q_start, hsp.q_end, hsp.s_start, hsp.s_end)
            str = str + ' %s' % seq
            if descrip:
                str = str + list[k].descr
            print str
        m = m + 1

### operates on a dictionary containing 'hits', list will be sorted
def stackListedHits(list, qseq, qname):
    "Current output routine for blastflt.py"
    m = 1
    seq = ''
    if b.type == 'TBLASTN':
        str = 'FRAME ';
    else:
        str = ''
    str = '    NR  %-38s %sE            SCORE   ID  QSTRT QEND  SSTRT SEND   %s' % (
        qname, str, qseq)
    print str
    klist = list.keys()
    klist.sort(sortByE)
    for k in klist:
        n = 1
        for hsp in list[k].hsp:
            end = hsp.end
            seq = align(qseq, hsp.query, hsp.subject, hsp.start, hsp.end)
            str = '%4d.%-2d' % (m, n)
            n = n + 1
            str = str + ' %-38s' % list[k].id
            if hsp.__dict__.has_key('frame'):
                str = str + ' %-5s' % hsp.frame
            str = str + ' %-#12.3e %-7.2f %-3d' % (hsp.e, hsp.score, hsp.identity)
            str = str + ' %-5d %-5d %-5d %-5d ' % (
                hsp.q_start, hsp.q_end, hsp.s_start, hsp.s_end)
            str = str + ' %s' % seq
            if descrip:
                str = str + list[k].descr
            print str
        m = m + 1

###
def getQuery(file):
    """Read in the first fasta formated sequence of 'file'.

       file: file name or open stream, handed to FIO unchanged

       Returns the tuple (query_name, query_seq). The name is the first
       whitespace free word after '>' on the header line. The sequence is
       collected from the following lines until a line does not start
       with a letter, '-', '?' or '.', or the input ends.

       A missing or malformed header line (including completely empty
       input) is reported through FIO.ioERR('parsing query').
       NOTE(review): ioERR is expected not to return - confirm in module
       'fio'.
    """
    query_seq  = ''
    query_name = ''
    f = FIO(file, 'r')
    try:
        l = f.getLine()
        ### getLine() hands back None at EOF (see CHANGES 01.05.2000), so
        ### guard the regex against empty input instead of letting
        ### search() fail with a TypeError
        m = l and search(r'^>(\S+)', l)
        if not m: f.ioERR('parsing query')
        query_name = m.group(1)
        l = f.getLine()
        while( l ):
            m = search(r'^([a-zA-Z\-\?\.]+)', l)
            if not m: break
            query_seq = query_seq + m.group(1)
            l = f.getLine()
    finally:
        ### release the handle in any case, as the main program does for
        ### the blast file
        f.close()
    ### defensive clean up; the pattern above admits neither whitespace
    ### nor '*', so both substitutions normally find nothing
    query_seq  = sub(r'\s+', '', query_seq)
    query_seq  = sub(r'(\*)', '', query_seq)
    return (query_name, query_seq)

###
def align(master, query, subject, start, end):
    """Return 'subject' laid out in the coordinates of 'master'.

       The subject is first reduced with 'replace' (using the gapped
       'query') and the result is then padded with this module's own
       'map' function using 'start' and 'end'.
    """
    projected = replace(query, subject)
    return map(master, projected, start, end)

###
def replace(query, subject):
    """Project the aligned 'subject' onto the ungapped query.

       query  : query line of a pair alignment, gaps written as '-'
       subject: subject line of the same alignment

       Every subject character standing opposite a gap in the query is
       dropped. The remaining characters are returned as one string in
       upper case, except for the characters directly in front of or
       directly behind a dropped stretch, which are returned in lower
       case to mark the position of the insertion.

       Query gaps at the very start or end of the alignment are handled
       as well: only the one existing neighbour is lower-cased. The
       previous implementation lower-cased the last character for a
       leading gap and raised an IndexError for a trailing gap.
    """
    qlen = len(query)
    kept = []
    for pos in range(len(subject)):
        if pos < qlen and query[pos] == '-':
            ### insertion relative to the query: there is no query column
            ### this character could be shown in
            continue
        follows_gap  = 0 < pos <= qlen and query[pos - 1] == '-'
        precedes_gap = pos + 1 < qlen and query[pos + 1] == '-'
        if follows_gap or precedes_gap:
            kept.append(subject[pos].lower())
        else:
            kept.append(subject[pos].upper())
    return ''.join(kept)

###
def map(master, subject, start, end):
    """Pad 'subject' with gap characters on both sides.

       (start - 1) '-' characters are put in front of 'subject' and
       (len(master) - end) '-' characters behind it. A count of zero or
       less adds nothing on that side.

       NOTE: this function hides the builtin 'map' inside this module.
    """
    leading  = '-' * (start - 1)
    trailing = '-' * (len(master) - end)
    return leading + subject + trailing

###
def findGaps(seq):
    """Locate all runs of '-' in 'seq'.

       Returns a list of tuples (first, last) with the 0-based, inclusive
       index of the first and last character of every run, in order of
       appearance. The list is empty if 'seq' holds no '-'.
    """
    gaps = []
    run_start = None          # index where the currently open run began
    for pos in range(len(seq)):
        if seq[pos] == '-':
            if run_start is None:
                run_start = pos
        elif run_start is not None:
            ### first non-gap character behind a run closes it
            gaps.append((run_start, pos - 1))
            run_start = None
    if run_start is not None:
        ### the sequence ended inside a run
        gaps.append((run_start, len(seq) - 1))
    return gaps

###
def sortByE(a, b):
    """Comparison function for sorting keys of the global 'hitlist'.

       'a' and 'b' are keys of 'hitlist'. Returns -1 if hit 'a' sorts in
       front of hit 'b', 1 if it sorts behind and 0 if no criterion
       separates them. Criteria, in this order:
       1. e-value, smaller first
       2. score, larger first
       3. identity of the first HSP, larger first (only if both hits
          have at least one HSP)
    """
    left  = hitlist[a]
    right = hitlist[b]

    if left.e < right.e:
        return -1
    elif left.e > right.e:
        return 1

    if left.score > right.score:
        return -1
    elif left.score < right.score:
        return 1

    if left.hsp and right.hsp:
        left_id  = left.hsp[0].identity
        right_id = right.hsp[0].identity
        if left_id > right_id:
            return -1
        elif left_id < right_id:
            return 1

    return 0

###
def getHSPlist(hits):
    """Collect the HSPs of all hits in 'hits' in one flat list.

       'hits' is a dictionary of references to Hits.

       Before an HSP is added to the list, every attribute of its hit
       that the HSP does not have itself is copied to the HSP. Attributes
       the HSP already has are never overwritten.

       Returns the list with references to all HSPs of all hits.

       NOTE: the HSPs are modified in place, i.e. this changes the blast
       data structure 'hits' refers to!
    """
    collected = []
    for hit in hits.values():
        for hsp in hit.hsp:
            ### start from the hit's attributes and let the HSP's own
            ### attributes take precedence ...
            merged = hit.__dict__.copy()
            merged.update(hsp.__dict__)
            ### ... so that writing the result back only adds what the
            ### HSP was missing
            hsp.__dict__.update(merged)
            collected.append(hsp)
    return collected

###
def Thresholds(list):
    """Apply the global thresholds 'evalue' and 'pid' to the hits in 'list'.

       An HSP is removed if 'evalue' is set and its e-value is larger, or
       if 'pid' is non-zero and its identity is smaller. A hit is removed
       from 'list' and from the iteration it belongs to (hit.itref.hits)
       if 'evalue' is set and the e-value of the hit is larger, or if no
       HSP of the hit is left.

       Returns the list
       [number of HSPs removed, number of hits removed,
        number of HSPs found initially]

       NOTE: call by reference function, modifies 'list' and the data
       structure 'list' refers to!
    """
    dropped_hsps = 0
    dropped_hits = 0
    total_hsps   = 0
    for hit_id in list.keys():
        hit  = list[hit_id]
        hsps = hit.hsp
        total_hsps = total_hsps + len(hsps)
        ### walk backwards so that deleting does not shift the indices
        ### still to be visited
        for idx in range(len(hsps) - 1, -1, -1):
            if (evalue is not None and hsps[idx].e > evalue) or (
                pid and hsps[idx].identity < pid):
                del hsps[idx]
                dropped_hsps = dropped_hsps + 1
        if (evalue is not None and hit.e > evalue) or not hit.hsp:
            del hit.itref.hits[hit_id]  # remove from blast data structure
            del list[hit_id]            # remove from dictionary with hits
            dropped_hits = dropped_hits + 1
    return [dropped_hsps, dropped_hits, total_hsps]

###
def Cluster(list):
    """Clusters redundant HSPs. Clustering is controlled by the
       commandline parameters minpid, maxoffset and minoverlap.

       All HSPs in all Hits (of dictionary 'list') are considered
       for the clustering. HSPs of hits whose 'id description' text
       matches the global regular expression 'keep' are excluded from
       clustering; each of them forms a cluster of its own.

       From each cluster the first HSP is kept as representative
       (the best one, see module 'misc' class 'AlignmentContainer').
       All other HSPs are removed from 'list' and the entire blast data
       structure, and so are hits left without any HSP.

       Returns a 4 elements list:
       0. dictionary with all non redundant hits and HSPs ('list' itself)
       1. number of sequences processed (initial number of HSPs)
       2. number of clusters (number of non redundant HSPs)
       3. number of non redundant hits (one hit can have several HSPs!)

       NOTE: call by reference function, modifies 'list' and the data
       structure 'list' refers to! The HSPs keep the helper attributes
       set here (best, evalue, pid, sequence, length, old_end).
    """
    protected  = []   # single element clusters for HSPs matching 'keep'
    candidates = []   # HSPs handed over to the AlignmentContainer
    nseq       = 0    # count sequences

    ### give every hsp the members the 'AlignmentContainer' works
    ### with and sort it into 'protected' or 'candidates'
    for hit_id in list.keys():
        hit = list[hit_id]
        for hsp in hit.hsp:
            hsp.best     = 0
            hsp.evalue   = hsp.e
            hsp.pid      = hsp.identity
            hsp.sequence = hsp.subject
            hsp.length   = len(hsp.sequence)
            ### remember the real end of the alignment on the query and
            ### extend q_end by the number of gaps in the query for the
            ### time of the clustering
            hsp.old_end = hsp.q_end
            hsp.q_end = hsp.q_end + string.count(hsp.query, '-')
            if findall(keep, hit.id + ' ' + hit.descr):
                protected.append([hsp])
            else:
                candidates.append(hsp)
            nseq = nseq + 1

    ### cluster all remaining hsps
    container = AlignmentContainer(candidates)
    container.min_id      = minpid
    container.max_offset  = maxoffset
    container.min_overlap = minoverlap
    clusters = protected + container.clusterHits()

    ### the first element of each cluster is its representative,
    ### flag it as 'best'
    representatives = []
    for members in clusters:
        chosen = members[0]
        chosen.best = 1
        representatives.append(chosen)

    ### delete all hsps that are not flagged, and all hits left without
    ### hsp, from the entire blast data structure and from 'list'.
    ### Surviving hsps get their original q_end back.
    hits = 0
    for hit_name in list.keys():
        hsps = list[hit_name].hsp
        ### walk backwards so that deleting does not shift the indices
        ### still to be visited
        for idx in range(len(hsps) - 1, -1, -1):
            if hsps[idx].best:
                hsps[idx].q_end = hsps[idx].old_end
            else:
                del hsps[idx]
        if list[hit_name].hsp:
            hits = hits + 1
        else:
            del list[hit_name].itref.hits[hit_name]
            del list[hit_name]
    return [list, nseq, len(representatives), hits]

###
def help():
    """Write the usage text to stdout and terminate the program.

       The text is followed by a newline, exactly as a print statement
       would emit it. Always leaves through sys.exit(0), i.e. raises
       SystemExit with status 0.

       NOTE: this function hides the builtin 'help' inside this module.
    """
    ### The text is written verbatim (no '%' formatting is applied), so a
    ### literal percent sign has to be spelled with a single '%'.
    text = """ blastflt.py version 1.1 (c) 1999-2003 Arne Mueller
    -   Filtering of GapBLAST/PSI-BLAST output   -

    usage:
    blastflt.py --fasta file | --blast file [options]

    either --fasta and/or --blast has to be specified, only
    one of these files can be stdin. The type of file comming
    from stdin is determined from the other necessery file
    option on the commandline.

    options:
    --fasta file query sequence in fasta formated file
    --blast file GapBLAST/PSI-BLAST file to be processed
    --out file name of the outputfile
    THRESHOLDS:
    --evalue float e-value cutoff for sequences included
      (default = off)
    --pid integer percent sequence identity cutoff of hit
      sequence with query as calculated by BLAST (default = off)
    --driftchk float value for basic drift check. If sequeneces with
      evalue <= driftchk of the first iteration in PSI-BLAST mode are
      not found in alater iteration the blastfilter stops one iteration
      before the drift (default = None, no check for drift). A resonable
      value is the h cutoff choosen for a PSI-BLAST run.
    CLUSTERING:
    --cluster triggers clustering of redundant hits (default = off)
    --minhsp integer minimmal number of HSPs for clustering, for
      queryies with less HSPs no clustering will we done (default = 50)
    --minoverlap integer minimum percent overlap between two
      sequences as measured by the longer sequence to join the
      same cluster (default = 90)
    --maxoffset integer maximum allowed N-terminal or C-terminal
      offset in residues between two sequences to join the same
      cluster (default = 10)
    --minpid integer minimal percent sequence identity between two
      sequences to join the same cluster (default = 80). Only the
      overlapping part of the two sequences is considered for the
      calculation. Positions where both sequences have gaps or non
      sequence characters are not included in the calculation.
    --keep regular expression to identify sequences belonging to
      a certain database (e.g. pdb). The description line and sequence
      id is parsed for the regular expression. Matching sequneces
      are excluded from clustering
      (Default = '\\|(pdb|scop)\\|')
    --nostacked: turn off stacking of listed hits
      pairwise alignments with gapped query and subject are shown.
    --descrip: append description line of hit to each output line
      (appended after sequence)

    example, the default parameters with clustering are:
    blastflt.py --fasta MG001.fasta --blast MG001.psiblast
                --cluster --minpid 80 --maxoffset 10
                --minoverlap 90 --keep '\\|(pdb|scop)\\|'

    ...skipping all hits/HSPs with e-value > 0.5 or %id < 20:
    blastflt.py --fasta MG001.fasta --blast MG001.psiblast
                --evalue 0.5 --pid 20

    """
    sys.stdout.write(text + '\n')
    sys.exit(0)

###############################################################################

###
### Main program
###

### global variables and defaults associated with commandline options
### (read by the functions above as module level names)
query_file   = sys.stdin          # file
out_file     = ''                 # file name
blast_file   = sys.stdin          # file
evalue       = None               # float
pid          = 0                  # %
driftchk     = None               # float
cluster      = 0                  # boolean
keep         = compile('\|(pdb|scop)\|')   # regex to catch database types
minpid       = 80                 # %
maxoffset    = 10                 # residues
minoverlap   = 90                 # %
minhsp       = 50                 # minimal number of HSPs to trigger clustering
nostacked    = 0                  # 0 ... stacked output, 1 ... non stacked output
descrip      = 0                  # 0|1
#-------------------------------------

### other global variables
hitlist    = {}                   # hit key -> best hit, filled by MyIteration
query_name = ''
query_seq  = ''

### logging, stderr is copied to a logfile (appended to ./blastflt.log
### in the current working directory)
log_file = open('./blastflt.log', 'a+')
log = StreamCP(sys.stderr, log_file)
sys.stderr = log

### process commandline arguments
### options
long_opt = ['fasta=',        # query in a fasta formated file
            'blast=',        # blast/psiblast output to be processed
            'evalue=',       # evalue cuttoff for including hits
            'pid=',          # percent sequence id cutoff for including hits
            'out=',          # output file
            'driftchk=',     # check for profile drift
            'cluster',       # cluster redundant hits
            'keep=',         # databases not to be clustered (--keep scop pdb)
            'minpid=',       # minimum percent sequence id for clustering
            'maxoffset=',    # maximum N or C term residue offset of for clustering
            'minoverlap=',   # minimum overlap of sequences for clustering
            'minhsp=',       # minimal number of HSPs to trigger clustering
            'dbsep=',        # value stored in Parser.dbsep; has to be listed
                             # here, otherwise getopt rejects '--dbsep' and
                             # the branch below can never be reached
            'help',          # help text and usage
            'nostacked',     # stacked output yes/no
            'descrip'        # append description line for hit to stacked output
            ]

### process commandline options
args = sys.argv[1:]
try:
    optlist, rest = getopt.getopt(args, '', long_opt)
except getopt.error:
    ### unknown option or missing option argument: report it (this also
    ### goes to the log file) instead of ending in a traceback
    sys.stderr.write('blastflt.py: %s\nType blastflt.py --help\n'
                     % sys.exc_info()[1])
    sys.exit(1)
for i in optlist:
    if   i[0] == '--help'      : help()
    elif i[0] == '--fasta'     : query_file   = i[1]
    elif i[0] == '--blast'     : blast_file   = i[1]
    elif i[0] == '--out'       : out_file     = i[1]
    elif i[0] == '--dbsep'     : Parser.dbsep = i[1]
    elif i[0] == '--evalue'    : evalue       = string.atof(i[1])
    elif i[0] == '--pid'       : pid          = string.atoi(i[1])
    elif i[0] == '--driftchk'  : driftchk     = string.atof(i[1])
    elif i[0] == '--cluster'   : cluster      = 1
    elif i[0] == '--keep'      : keep         = compile(i[1])
    elif i[0] == '--minpid'    : minpid       = string.atoi(i[1])
    elif i[0] == '--maxoffset' : maxoffset    = string.atoi(i[1])
    elif i[0] == '--minoverlap': minoverlap   = string.atoi(i[1])
    elif i[0] == '--minhsp'    : minhsp       = string.atoi(i[1])
    elif i[0] == '--nostacked' : nostacked    = 1
    elif i[0] == '--descrip'   : descrip      = 1

### both inputs still at their default means both would be read from stdin
if query_file == sys.stdin and blast_file == sys.stdin:
    sys.stderr.write("""
          You cannot specify 'fasta' and 'blast' file as stdin,
          one has to be a real filename!\n""")
    sys.exit(1)

### send everything printed from here on to the output file, if requested
if out_file:
    out = open(out_file, 'w')
    sys.stdout = out

### file reading
### parsing/generation of blast file/object; while the tree is built
### class MyIteration fills the global 'hitlist'
f = FIO(blast_file, 'r')
sys.stderr.write('parsing blast file %s ...\n' % blast_file)
b = Blast(f, iteration=MyIteration)
if driftchk is not None:
    b.driftcutoff = driftchk
else:
    b.drift = None
b.createBlastTree()
f.close()

### delete all hits that are set to None (redundant hits,
### see class MyIteration)
for iteration in b.iter:
    for key in iteration.hits.keys():
        if iteration.hits[key] == None:
            del iteration.hits[key]

### post process hitlist
### get rid of hits outside the thresholds (pid, evalue)
if evalue is not None or pid:
    removed_hsps, removed_hits, total_hsps = Thresholds(hitlist)
    sys.stderr.write('%s: HSPs total = %d, HSPs outside threshold = %d (%d Hits)\n'
                     %(b.qname, total_hsps, removed_hsps, removed_hits))

### cluster redundant hits, but only if requested and if there are at
### least 'minhsp' HSPs left
n_hsp = 0
for entry in hitlist.values():
    n_hsp = n_hsp + len(entry.hsp)
if cluster and n_hsp >= minhsp:
    sys.stderr.write('%s: clustering redundant hits ...\n' % b.qname)
    started = time.clock()
    hitlist, n_sequences, n_clusters, n_hits = Cluster(hitlist)
    sys.stderr.write('%s: sequences = %d, clusters = %d, hits = %d\n' %(
        b.qname, n_sequences, n_clusters, n_hits))
    sys.stderr.write('cpu time for clustering: %g sec\n' % (
        time.clock() - started))

### parse query sequence
### generation of stacked alignment (mview like 'new' format)
query_name, query_seq = getQuery(query_file)
sys.stderr.write('generating output ...\n')
### (stackHits(b, query_seq, query_name) is the unsorted alternative)
if not nostacked:
    stackListedHits(hitlist, query_seq, query_name)
else:
    nostackListedHits(hitlist, query_name)
