#include <CorrelationTools.h>
#include <debug.h>

/** Removes isolated correlations, keeping only those that belong to a stem of length >= 2.
 *  A correlation (i,j) survives when
 *    - its span j - i is at least distMin, and
 *    - the input also contains its stacked inner neighbor (i+1, j-1) or its outer neighbor (i-1, j+1).
 *  Neighbors are looked up among ALL input correlations, including those that fail the distMin test.
 *  @param correlations input correlations (looked up by Correlation::hash)
 *  @param distMin      minimum span stop - start for a correlation to be kept
 *  @return surviving correlations, in input order
 */
CorrelationTools::result_container
CorrelationTools::filterIsolatedCorrelations2(const result_container& correlations, length_type distMin) {
  // Index every input correlation by its hash for neighbor lookups.
  set<string> knownHashes;
  for (result_container::const_iterator it = correlations.begin(); it != correlations.end(); ++it) {
    knownHashes.insert(it->hash());
  }

  result_container kept;
  for (result_container::const_iterator it = correlations.begin(); it != correlations.end(); ++it) {
    const Correlation& candidate = *it;
    if ((candidate.getStop() - candidate.getStart()) < distMin) {
      continue; // span too short
    }
    // Inner neighbor first; the outer neighbor is only looked up when the inner one is missing.
    const bool hasInnerNeighbor =
      knownHashes.count(Correlation::hash(candidate.getStart() + 1, candidate.getStop() - 1)) > 0;
    if (hasInnerNeighbor
        || knownHashes.count(Correlation::hash(candidate.getStart() - 1, candidate.getStop() + 1)) > 0) {
      kept.push_back(candidate); // not isolated
    }
  }
  return kept;
}

/** Removes correlations that are not part of a stem of length >= 3.
 *  A correlation (i,j) with span j - i >= distMin is kept when it can be the
 *    - first pair of such a stem: (i+1, j-1) and (i+2, j-2) are present,
 *    - middle pair: (i-1, j+1) and (i+1, j-1) are present, or
 *    - last pair:   (i-2, j+2) and (i-1, j+1) are present.
 *  Presence is checked against ALL input correlations, including those that fail the distMin test.
 *  @param correlations input correlations (looked up by Correlation::hash)
 *  @param distMin      minimum span stop - start for a correlation to be kept
 *  @return surviving correlations, in input order
 */
CorrelationTools::result_container
CorrelationTools::filterIsolatedCorrelations3(const result_container& correlations, length_type distMin) {
  set<string> knownHashes;
  for (result_container::const_iterator it = correlations.begin(); it != correlations.end(); ++it) {
    knownHashes.insert(it->hash());
  }

  result_container kept;
  for (result_container::const_iterator it = correlations.begin(); it != correlations.end(); ++it) {
    const Correlation& candidate = *it;
    if ((candidate.getStop() - candidate.getStart()) < distMin) {
      continue; // span too short
    }
    const bool inner1 = knownHashes.count(Correlation::hash(candidate.getStart() + 1, candidate.getStop() - 1)) > 0;
    const bool inner2 = knownHashes.count(Correlation::hash(candidate.getStart() + 2, candidate.getStop() - 2)) > 0;
    const bool outer1 = knownHashes.count(Correlation::hash(candidate.getStart() - 1, candidate.getStop() + 1)) > 0;
    const bool outer2 = knownHashes.count(Correlation::hash(candidate.getStart() - 2, candidate.getStop() + 2)) > 0;

    const bool opensStem  = inner1 && inner2;
    const bool insideStem = outer1 && inner1;
    const bool closesStem = outer2 && outer1;
    if (opensStem || insideStem || closesStem) {
      kept.push_back(candidate);
    }
  }
  return kept;
}

/** Converts correlations into (antiparallel) stems.
 *  A correlation (i,j) starts a stem when (i-1, j+1) is not among the input correlations.
 *  The stem is then extended for len = 1, 2, ... for as long as (i+len, j-len) is present.
 *  Every stem whose length reaches stemLengthMin is returned. Its energy is set to the highest
 *  score (density) among its member correlations: the "worst case".
 *  If a hash occurs more than once in the input, map::insert keeps the score of the first occurrence.
 *  @param correlations  input correlations
 *  @param distMin       if > 0, correlations with stop - start < distMin cannot start a stem
 *  @param stemLengthMin minimum number of stacked pairs for a stem to be reported
 *  @return stems, in the order of their starting correlations in the input
 */
Vec<Stem>
CorrelationTools::convertCorrelationsToStems(const result_container& correlations, length_type distMin, Stem::index_type stemLengthMin) {
  typedef map<string, double> score_map;
  score_map scores; // hash -> score (density)
  for (result_container::const_iterator i = correlations.begin(); i != correlations.end(); ++i) {
    scores.insert(pair<string, double>(i->hash(), i->getScore()));
  }

  Vec<Stem> result;
  for (result_container::const_iterator i = correlations.begin(); i != correlations.end(); ++i) {
    const Correlation& corr = *i;
    if ((distMin > 0) && ((corr.getStop() - corr.getStart()) < distMin)) {
      continue;
    }
    // Only the outermost pair of a stem may start it.
    if (scores.find(Correlation::hash(corr.getStart() - 1, corr.getStop() + 1)) != scores.end()) {
      continue;
    }
    double highestScore = corr.getScore();
    Stem::index_type len = 1;
    // Extend inwards; a single lookup per step both tests membership and yields the score.
    for (;;) {
      score_map::const_iterator member = scores.find(Correlation::hash(corr.getStart() + len, corr.getStop() - len));
      if (member == scores.end()) {
        break;
      }
      if (member->second > highestScore) {
        highestScore = member->second;
      }
      ++len;
    }
    if (len >= stemLengthMin) {
      Stem stem(static_cast<Stem::index_type>(corr.getStart()),
                static_cast<Stem::index_type>(corr.getStop()), len);
      stem.setEnergy(highestScore); // highest density: "worst case"
      result.push_back(stem);
    }
  }
  return result;
}

/** Converts correlations into FORWARD stems: if (i,j) binds, then (i+k, j+k) binds too for 0 <= k < len.
 *  A correlation (i,j) starts a forward stem when (i-1, j-1) is not among the input correlations.
 *  The stem is then extended for len = 1, 2, ... for as long as (i+len, j+len) is present.
 *  Every stem whose length reaches stemLengthMin is returned. Its energy is set to the highest
 *  score (density) among its member correlations.
 *  If a hash occurs more than once in the input, map::insert keeps the score of the first occurrence.
 *  @param correlations  input correlations
 *  @param distMin       if > 0, correlations with stop - start < distMin cannot start a stem
 *  @param stemLengthMin minimum number of pairs for a stem to be reported
 *  @return stems, in the order of their starting correlations in the input
 */
Vec<Stem>
CorrelationTools::convertCorrelationsToForwardStems(const result_container& correlations, length_type distMin, Stem::index_type stemLengthMin) {
  typedef map<string, double> score_map;
  score_map scores; // hash -> score (density)
  for (result_container::const_iterator i = correlations.begin(); i != correlations.end(); ++i) {
    scores.insert(pair<string, double>(i->hash(), i->getScore()));
  }

  Vec<Stem> result;
  for (result_container::const_iterator i = correlations.begin(); i != correlations.end(); ++i) {
    const Correlation& corr = *i;
    if ((distMin > 0) && ((corr.getStop() - corr.getStart()) < distMin)) {
      continue;
    }
    // Only the first pair of a forward stem may start it.
    if (scores.find(Correlation::hash(corr.getStart() - 1, corr.getStop() - 1)) != scores.end()) {
      continue;
    }
    double highestScore = corr.getScore(); // stores density
    Stem::index_type len = 1;
    // Extend forwards; a single lookup per step both tests membership and yields the score.
    for (;;) {
      score_map::const_iterator member = scores.find(Correlation::hash(corr.getStart() + len, corr.getStop() + len));
      if (member == scores.end()) {
        break;
      }
      if (member->second > highestScore) {
        highestScore = member->second;
      }
      ++len;
    }
    if (len >= stemLengthMin) {
      Stem stem(static_cast<Stem::index_type>(corr.getStart()),
                static_cast<Stem::index_type>(corr.getStop()), len);
      stem.setEnergy(highestScore);
      result.push_back(stem);
    }
  }
  return result;
}

/** Pushes one correlation (start, stop) into the progressive single-linkage filter.
 *  Every cluster returned by that push is converted into stems with distMin == 1 and
 *  stemLengthMin == 1.
 *  @return the stems of all returned clusters, concatenated in cluster order
 */
Vec<Stem>
CorrelationTools::singleLinkageFilter(length_type start, length_type stop, SingleLinkage2DProgressiveFilter& filter) {
  SingleLinkage2DProgressiveFilter::result_type releasedClusters = filter.push(start, stop);
  Vec<Stem> collected;
  for (size_type c = 0; c < releasedClusters.size(); ++c) {
    // Rebuild the cluster members as Correlation objects.
    result_container clusterCorrelations;
    ASSERT(clusterCorrelations.size() == 0);
    for (size_type m = 0; m < releasedClusters[c].size(); ++m) {
      clusterCorrelations.push_back(Correlation(releasedClusters[c][m].first, releasedClusters[c][m].second));
    }
    // distMin == 1, stemLengthMin == 1
    const Vec<Stem> clusterStems = convertCorrelationsToStems(clusterCorrelations, 1, 1);
    for (Vec<Stem>::size_type s = 0; s < clusterStems.size(); ++s) {
      collected.push_back(clusterStems[s]);
    }
  }
  return collected;
}

/** Runs a set of stems through the progressive single-linkage filter.
 *  Steps:
 *    1. Expand each stem into its base-pair correlations (start + k, stop - k), 0 <= k < length.
 *    2. Sort the correlations by Correlation's operator<.
 *    3. Push them into the filter one at a time, collecting the stems produced by each push.
 *  If anything was pushed, a dummy correlation far beyond the last one is pushed afterwards to
 *  flush the clusters still buffered in the filter. After that the filter is expected to hold
 *  exactly one element (presumably the dummy), and flushAll() is called.
 *  @return all stems produced, in production order
 */
Vec<Stem>
CorrelationTools::singleLinkageFilter(const Vec<Stem>& stems, SingleLinkage2DProgressiveFilter& filter) {
  // 1. Expand stems into individual base-pair correlations.
  result_container basePairs;
  for (Vec<Stem>::size_type s = 0; s < stems.size(); ++s) {
    const Stem& stem = stems[s];
    for (Stem::index_type k = 0; k < stem.getLength(); ++k) {
      length_type start = stem.getStart() + k;
      length_type stop = stem.getStop() - k;
      basePairs.push_back(Correlation(start, stop));
    }
  }

  // 2. Sort the correlations.
  sort(basePairs.begin(), basePairs.end());

  // 3. Feed them to the filter, one at a time.
  Vec<Stem> result;
  for (size_type p = 0; p < basePairs.size(); ++p) {
    Vec<Stem> released = singleLinkageFilter(basePairs[p].getStart(), basePairs[p].getStop(), filter);
    for (Vec<Stem>::size_type r = 0; r < released.size(); ++r) {
      result.push_back(released[r]);
    }
  }

  // Flush the stems still buffered in the filter by pushing a far-away dummy correlation.
  if (basePairs.size() > 0) {
    const Correlation& last = basePairs[basePairs.size() - 1];
    Vec<Stem> released = singleLinkageFilter(last.getStart()+100000, 13, filter);
    for (Vec<Stem>::size_type r = 0; r < released.size(); ++r) {
      result.push_back(released[r]);
    }
    ASSERT(filter.getElementCount() == 1);
    filter.flushAll();
  }
  return result;
}
