/*
 * bb_pp_bias_correction.cpp
 *
 *  Created on: 1 Nov 2010
 *      Author: jmacdona
 */

#include "bb_pp_bias_correction.h"

using namespace boost;
using namespace PRODART::UTILS;
using namespace PRODART;
using namespace PRODART::POSE;
using namespace PRODART::POSE::POTENTIALS;
using namespace PRODART::POSE::META;
using namespace PRODART::POSE_UTILS;
using namespace std;

namespace PRODART {
namespace POSE {
namespace POTENTIALS {
namespace BB{

/*
 * Factory: allocate a default-constructed bb_pp_bias_correction and hand it
 * back already owned by a shared pointer.
 */
bb_pp_bias_correction_shared_ptr new_bb_pp_bias_correction(){
	return bb_pp_bias_correction_shared_ptr(new bb_pp_bias_correction());
}



// Number of bins along each dihedral axis (phi and psi); the full table
// therefore holds grid_divs * grid_divs sectors.
const int bb_pp_bias_correction::grid_divs = 20;

/*
 * Registers the potential under the name "bb_pp_bias_corr" and allocates a
 * zero-filled grid_divs x grid_divs table for both the observed counts and
 * the per-sector energies.
 */
bb_pp_bias_correction::bb_pp_bias_correction(){
	name_vector.clear();
	name_vector.push_back(potentials_name("bb_pp_bias_corr"));

	const int n_sectors = grid_divs * grid_divs;
	overall_freq_map.resize(n_sectors, 0);
	energies.resize(n_sectors, 0);

	totalResidueCount = 0;
	// Angular width (radians) of one bin over the full [-PI, PI) range.
	phi_psi_increment = (2.0 * PI) / static_cast<double>(grid_divs);
	// Sectors whose energy exceeds this value are reported as hot spots.
	hot_spot_cutoff = 1.6;
}

bool bb_pp_bias_correction::init(){

	const string db_path = PRODART::ENV::get_option_value<string>("database:path:bb_pp_bias_correction");
	std::ifstream inputdb(db_path.c_str(), ios::in);
	this->load_data(inputdb);
	inputdb.close();

	//count_cutoff = PRODART::ENV::get_option_value<int>("potential:bb:bb_pp_bias_correction:count_cutoff");

	return true;
}

/*
 * Decomposes a flat sector index into its (phi, psi) bin pair.
 * Sectors are laid out with phi varying fastest:
 *     sector = phi + grid_divs * psi
 */
void bb_pp_bias_correction::getPhiPsiBin(const int phi_psi_sector, int& phi, int& psi) const{
	// psi is assigned first, exactly as before, so behaviour is unchanged
	// even if the caller passes the same variable for both outputs.
	psi = phi_psi_sector / grid_divs;
	phi = phi_psi_sector % grid_divs;
}

/*
 * Lower edge, in radians, of dihedral bin `bin`.
 * Bins start at -PI and are phi_psi_increment radians wide.
 */
double bb_pp_bias_correction::get_lower_bound(int bin) const{
	const double offset_from_start = static_cast<double>(bin) * phi_psi_increment;
	return offset_from_start - PI;
}

/*
 * Centre, in radians, of dihedral bin `bin`: its lower edge plus half a bin
 * width.
 */
double bb_pp_bias_correction::get_grid_centre(int bin) const{
	const double lower_edge = (static_cast<double>(bin) * phi_psi_increment) - PI;
	const double half_width = phi_psi_increment / 2;
	return lower_edge + half_width;
}



/*
 * Sums the tabulated bias energy of the (phi, psi) sector of every
 * non-terminal residue in the pose and records the total under this
 * potential's name in `energies_map`.
 *
 * Returns whatever energies_map.add_energy_component() returns for the
 * component just added.
 */
double bb_pp_bias_correction::get_energy(const PRODART::POSE::META::pose_meta_shared_ptr pose_meta_,
		potentials_energies_map& energies_map) const{

	const bb_pose_meta_shared_ptr bb_meta_dat = static_pointer_cast<bb_pose_meta, pose_meta_interface>(pose_meta_);
	const const_pose_shared_ptr pose_ = pose_meta_->get_pose();
	const int resCount = pose_->get_residue_count();

	double total_energy = 0;

	for (int i = 0; i < resCount; i++){
		const const_residue_shared_ptr res_0 = pose_->get_residue(i);

		// Terminal residues contribute nothing (presumably because one of
		// their backbone dihedrals is undefined).
		if (res_0->is_terminal()){
			continue;
		}

		// Use grid_divs rather than a literal so the sector index always
		// matches the grid_divs*grid_divs layout of `energies`.
		const int phipsiSector = bb_meta_dat->get_phi_psi_sector(i, grid_divs);
		total_energy += energies[phipsiSector];
	}

	return energies_map.add_energy_component(name_vector[0], total_energy);
}

/*
 * The bias energy is a per-sector table lookup (piecewise constant in phi
 * and psi), so no gradient is accumulated; this simply returns get_energy().
 */
double bb_pp_bias_correction::get_energy_with_gradient(const PRODART::POSE::META::pose_meta_shared_ptr pose_meta_,
		potentials_energies_map& energies_map) const{
	const double energy = get_energy(pose_meta_, energies_map);
	return energy;
}

/*
 * Flags (sets vec[i] = true) every non-terminal residue whose (phi, psi)
 * sector energy exceeds hot_spot_cutoff. Entries for other residues are
 * left untouched. `vec` is grown (new entries false) if it is shorter than
 * the residue count.
 */
void bb_pp_bias_correction::get_residue_hot_spots(const PRODART::POSE::META::pose_meta_shared_ptr pose_meta_,
		bool_vector& vec) const{
	const bb_pose_meta_shared_ptr bb_meta_dat = static_pointer_cast<bb_pose_meta, pose_meta_interface>(pose_meta_);
	const const_pose_shared_ptr pose_ = pose_meta_->get_pose();
	const int resCount = pose_->get_residue_count();

	// Guard against writing past the end of a caller-supplied vector.
	if (static_cast<int>(vec.size()) < resCount){
		vec.resize(resCount, false);
	}

	for (int i = 0; i < resCount; i++){
		const const_residue_shared_ptr res_0 = pose_->get_residue(i);

		if (res_0->is_terminal()){
			continue;
		}

		// Same sector indexing as get_energy(): tied to grid_divs, not a literal.
		const int phipsiSector = bb_meta_dat->get_phi_psi_sector(i, grid_divs);
		if (energies[phipsiSector] > hot_spot_cutoff){
			vec[i] = true;
		}
	}
}


std::istream& bb_pp_bias_correction::load_data( std::istream& input ){


    string lineStr;


    long length, lineNum = 0 ;


    string_vector SplitVec;
    while ( !input.eof() ) {
            getline(input, lineStr);
            string resStr;
            lineNum++;

            length = lineStr.length();

            //cout << endl << lineNum << " " << length << " ";

            if (length > 0) {
                    split( SplitVec, lineStr, is_any_of("\t") );
                    if ( SplitVec[0].substr(0,1).compare("#") != 0
                     && SplitVec[0].substr(0,1).compare("") != 0 && SplitVec.size() >= 3 ){
                        string paraName = SplitVec[0];
                        trim(paraName);
                        const int bin = lexical_cast<int>(SplitVec[1]);
                        const double energy = lexical_cast<double>(SplitVec[2]);

                        if ( paraName.compare("PPBC") == 0 ) {

                        	this->energies[bin] = energy;

                        }



                    }

            }

    }

	return input;

}



bool bb_pp_bias_correction::addPseudoCounts( const double sumCount){
	double_vector::iterator iter;

	for (iter = this->overall_freq_map.begin(); iter != this->overall_freq_map.end(); iter++){
		*iter += sumCount;
		this->totalResidueCount += sumCount;
	}

	return true;
}


bool bb_pp_bias_correction::calculateScores( void ){

	const int numBins = overall_freq_map.size();
	energies.resize(numBins, 0);

	const double expected_count = totalResidueCount / static_cast<double>(numBins);

	for (int bin = 0; bin < numBins; bin++){
		this->energies[bin] = -std::log(overall_freq_map[bin] / expected_count);
	}

	return true;


}


/*
 * Records one (phi, psi) observation, in radians, by incrementing its sector
 * count and the running total. Angles outside [-PI, PI] (or NaN) and
 * sectors outside the table are reported on stderr and ignored.
 */
void bb_pp_bias_correction::addToDB(const double phi, const double psi){
	const bool phi_in_range = (phi >= - PI && phi <= PI);
	const bool psi_in_range = (psi >= - PI && psi <= PI);

	if (!(phi_in_range && psi_in_range)){
		cerr << "ERROR: dihedrals out of range" << endl;
		cerr << phi << "\t" << psi << "\t" << endl;
		return;
	}

	const int phipsiSector = pose_meta_interface::get_phi_psi_sector(phi, psi, grid_divs);
	const int tableSize = static_cast<int>(overall_freq_map.size());

	if (phipsiSector < 0 || phipsiSector >= tableSize){
		cerr << "ERROR: bin out of range" << endl;
		cerr << phipsiSector << "\t" << phi << "\t" << psi << "\t" << endl;
		return;
	}

	overall_freq_map[phipsiSector] += 1;
	totalResidueCount += 1;
}


/*
 * Writes one tab-separated line per sector:
 *   centre phi (deg), centre psi (deg), energy, observed count,
 *   lower-edge phi (deg), lower-edge psi (deg), phi bin, psi bin, sector
 * followed by a trailing tab. Returns `output`.
 */
std::ostream& bb_pp_bias_correction::output_phi_psi_info(std::ostream& output){

	const double rad_to_deg = 180.0 / PI;
	const int numBins = static_cast<int>(energies.size());

	for (int sector = 0; sector < numBins; ++sector){
		int phi_bin = 0;
		int psi_bin = 0;
		getPhiPsiBin(sector, phi_bin, psi_bin);

		const double centre_phi = get_grid_centre(phi_bin);
		const double centre_psi = get_grid_centre(psi_bin);
		const double lower_phi = get_lower_bound(phi_bin);
		const double lower_psi = get_lower_bound(psi_bin);

		output << centre_phi * rad_to_deg << "\t"
			   << centre_psi * rad_to_deg << "\t"
			   << energies[sector] << "\t"
			   << overall_freq_map[sector] << "\t"
			   << lower_phi * rad_to_deg << "\t"
			   << lower_psi * rad_to_deg << "\t"
			   << phi_bin << "\t"
			   << psi_bin << "\t"
			   << sector << "\t"
			   << endl;
	}

	return output;
}



}
}
}
}

