#### Author: Mikael Trellet (February 2011)

#### Compute the center of each cluster from its i-RMSD and HADDOCK scores, and provide the  ####
#### information about the 4 best complexes of each cluster for the final results graph.   ####

import os
from os       import popen, _exit, environ, pathsep
from sys      import argv, stderr, stdout
from math 	  import sqrt
import numpy
import re

# Per-cluster results. Every dict below is keyed by the integer cluster
# number taken from the HADDOCK cluster name (e.g. 3 for "file.nam_clust3").
num_complex = dict()  # cluster -> list of its (up to) 4 best complex numbers, as strings
nb_cluster = 0  # number of clusters generated by HADDOCK (not read anywhere in this file)

# Structural metrics over each cluster's best complexes:
# [i-RMSD, l-RMSD, Fnat] mean / standard deviation.
average, sd = dict(), dict()
# Interface-ligand RMSD mean / standard deviation.
average_ilrmsd, sd_ilrmsd = dict(), dict()

# Scores copied from clusters.stat_best4 (mean / standard deviation).
average_h, sd_h = dict(), dict()  # HADDOCK score
average_vdw, sd_vdw = dict(), dict()  # van der Waals energy
average_elec, sd_elec = dict(), dict()  # electrostatic energy
average_air, sd_air = dict(), dict()  # ambiguous interaction restraints (AIRs)


def getBest4all(cluster):
    """Record the four best complexes of one HADDOCK cluster in ``num_complex``.

    Reads ``../<cluster>_best4`` (a file generated by the HADDOCK server run).
    For each of its first four non-blank lines, keeps the last run of digits
    found in the line's first field (the complex number, kept as a string).

    :param cluster: cluster name such as ``file.nam_clust3``. The integer
        formed by its trailing digits becomes the key in ``num_complex``.
    :raises ValueError: if ``cluster`` does not end in a number.
    """
    # str.strip() removes a *set of characters*, not a prefix, so parse the
    # trailing digits explicitly instead.
    match = re.search('[0-9]+$', cluster)
    if match is None:
        raise ValueError('cluster name has no trailing number: %r' % cluster)
    num_clust = int(match.group())

    digits = re.compile('[0-9]+')
    best = []
    with open('../' + cluster + '_best4') as clust:
        for line in clust:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines instead of crashing on fields[0]
            best.append(digits.findall(fields[0])[-1])
            if len(best) == 4:
                break
    num_complex[num_clust] = best

def _cluster_index(name):
    """Return the cluster number given by the trailing digits of *name*.

    ``'file.nam_clust3'`` -> ``3``. ``str.strip`` is not used for this because
    it removes a set of characters rather than a prefix.
    """
    match = re.search('[0-9]+$', name)
    if match is None:
        raise ValueError('cluster name has no trailing number: %r' % name)
    return int(match.group())


def _mean_sd(values):
    """Return ``(mean, population standard deviation)`` of *values*.

    Uses the divide-by-N formula (like ``numpy.std`` with its default
    ``ddof=0``) over however many values are given. An empty sequence
    yields ``(0.0, 0.0)``.
    """
    if not values:
        return 0.0, 0.0
    mean = sum(values) / float(len(values))
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, sqrt(variance)


def _write_plot_file(path, rows):
    """Write each row of *rows* to *path* as one space-separated line."""
    with open(path, 'w') as out:
        for row in rows:
            out.write(' '.join(str(v) for v in row) + '\n')


def getBest4(nb_clust):
    """Compute per-cluster statistics and write the ``plotCI_*.file`` inputs.

    1. Read ``../clusters.stat_best4``. For each cluster row, call
       ``getBest4all`` to fill ``num_complex``. Also store the HADDOCK score,
       van der Waals, electrostatics and AIR mean/SD (columns 1-2, 14-15,
       16-17 and 18-19, kept as the original strings).
    2. Read ``complex_HS_irmsd_lrmsd_fnat.list`` once. For clusters
       1..nb_clust, compute the mean and population SD over the cluster's
       best complexes of columns 2, 3 and 4 (stored in that order in
       ``average[c]`` / ``sd[c]``; plotted as irmsd, lrmsd and fnat). Do the
       same for column 8 (``average_ilrmsd`` / ``sd_ilrmsd``).
    3. Write one ``plotCI_<metric>.file`` per metric. Each line is
       ``x_mean y_mean y_sd x_sd`` for one cluster.

    :param nb_clust: number of clusters; clusters are numbered 1..nb_clust.
    """
    with open('../clusters.stat_best4') as stat_file:
        for line in stat_file:
            fields = line.split()
            # Skip blank lines and the '#Cluster ...' header (or other comments).
            if not fields or fields[0].startswith('#'):
                continue
            cluster = fields[0]
            getBest4all(cluster)
            c = _cluster_index(cluster)
            average_h[c], sd_h[c] = fields[1], fields[2]
            average_vdw[c], sd_vdw[c] = fields[14], fields[15]
            average_elec[c], sd_elec[c] = fields[16], fields[17]
            average_air[c], sd_air[c] = fields[18], fields[19]

    # Read the per-complex metrics once instead of once per cluster and metric.
    with open('complex_HS_irmsd_lrmsd_fnat.list') as listing:
        rows = [line.split() for line in listing if line.strip()]

    for c in range(1, nb_clust + 1):
        members = set(num_complex[c])
        picked = [row for row in rows if row[0] in members]
        means, sds = [], []
        # The SD now covers every picked complex; the old code only summed
        # the first three of the four.
        for col in (2, 3, 4):
            m, s = _mean_sd([float(row[col]) for row in picked])
            means.append(m)
            sds.append(s)
        average[c] = means
        sd[c] = sds
        average_ilrmsd[c], sd_ilrmsd[c] = _mean_sd([float(row[8]) for row in picked])

    clusters = range(1, nb_clust + 1)
    plots = (
        ('plotCI_irmsd.file', lambda i: (average[i][0], average_h[i], sd_h[i], sd[i][0])),
        ('plotCI_lrmsd.file', lambda i: (average[i][1], average_h[i], sd_h[i], sd[i][1])),
        ('plotCI_fnat.file', lambda i: (average[i][2], average_h[i], sd_h[i], sd[i][2])),
        ('plotCI_vdw.file', lambda i: (average[i][0], average_vdw[i], sd_vdw[i], sd[i][0])),
        ('plotCI_elec.file', lambda i: (average[i][0], average_elec[i], sd_elec[i], sd[i][0])),
        ('plotCI_air.file', lambda i: (average[i][0], average_air[i], sd_air[i], sd[i][0])),
        ('plotCI_ilrmsd.file', lambda i: (average_ilrmsd[i], average_h[i], sd_h[i], sd_ilrmsd[i])),
    )
    for path, make_row in plots:
        _write_plot_file(path, [make_row(i) for i in clusters])
	
def getSubMatrix():
    """Copy the rows of the best complexes into ``complex_HS_rmsd_best4.list``.

    Each row of ``complex_HS_irmsd_lrmsd_fnat.list`` whose first field is
    listed in a ``num_complex`` entry is written out once for every cluster
    entry that lists it. All other rows are dropped.
    """
    with open('complex_HS_irmsd_lrmsd_fnat.list', 'r') as src:
        with open('complex_HS_rmsd_best4.list', 'w') as dst:
            for row in src:
                owners = [members for members in num_complex.values()
                          if row.split()[0] in members]
                dst.write(row * len(owners))

def getClustForR():
    """Write ``clust_R.list``: the cluster label of every listed complex.

    For each line of ``complex_HS_rmsd_best4.list``, the first field is looked
    up in ``num_complex``. The first cluster key (in dict iteration order) whose
    list contains it becomes the label; ``0`` is used when no cluster does.
    Labels go on a single line, each preceded by a space, with no newline.
    """
    labels = []
    with open('complex_HS_rmsd_best4.list', 'r') as listing:
        for row in listing:
            complex_id = row.split()[0]
            owner = next((key for key in num_complex if complex_id in num_complex[key]), 0)
            labels.append(' ' + str(owner))
    with open('clust_R.list', 'w') as out:
        out.write(''.join(labels))

def main(argv):
    """Script entry point: ``<script> NB_CLUSTERS``.

    Runs ``getBest4`` for clusters 1..NB_CLUSTERS and returns 0. If the
    argument is missing, not an integer, or negative, prints a usage message
    to stderr and returns 1 instead of raising a traceback.
    """
    try:
        nb_clust = int(argv[1])
    except (IndexError, ValueError):
        nb_clust = None
    if nb_clust is None or nb_clust < 0:
        prog = argv[0] if argv else 'script'
        stderr.write('usage: %s NB_CLUSTERS (a non-negative integer)\n' % prog)
        return 1
    getBest4(nb_clust)
    # getSubMatrix() and getClustForR() used to run here when nb_clust != 0;
    # they are currently disabled.
    return 0
# Run only when executed as a script, so importing this module has no side
# effects. SystemExit (unlike os._exit) performs normal interpreter shutdown,
# including flushing stdio buffers.
if __name__ == '__main__':
    raise SystemExit(main(argv))
