package cjm.chalk;

import java.util.*;

/**
 * A k-d tree of {@link Scan}s used to look up the scans nearest to a query scan.
 *
 * <p>{@code planes} lists the property indices (passed to {@code Scan.getProperty}) that
 * span the search space. A node at depth {@code d} splits on property
 * {@code planes[d % planes.length]} and stores its split value in
 * {@code Point[d % planes.length]}. Leaves ("data nodes") keep up to {@code nodeSize}
 * scans sorted on their splitting property and split near the median when they overflow.
 *
 * <p>Distances are weighted squared Euclidean distances over the planes:
 * {@code sum over p in planes of ((a.getProperty(p) - b.getProperty(p)) * weights[p])^2},
 * so {@code weights} is indexed by property index, not by plane position.
 *
 * <p>No synchronization is performed; {@link #addDataNode} mutates the tree in place.
 */
public class Tree {
	
	private Node _root;
	/** Property indices making up the search space, cycled through by depth. */
	private final int[] _planes;
	/** Number of scans a leaf holds before it tries to split. */
	private final int _nodeSize;
	
	/**
	 * Creates a tree, optionally pre-seeding its upper levels with fixed split values.
	 *
	 * @param planes       property indices that make up the search space; the tree splits on
	 *                     them in this order, cycling when it is deeper than {@code planes.length}
	 * @param nodeSize     number of scans a leaf holds before it is split
	 * @param planeBisects optional fixed split values, one per tree level starting at the root;
	 *                     every node on level {@code i} splits property
	 *                     {@code planes[i % planes.length]} at {@code planeBisects[i]}.
	 *                     May be {@code null} (or empty), in which case the tree starts empty.
	 */
	public Tree(int[] planes, int nodeSize, double[] planeBisects){
		_planes = planes;
		_nodeSize = nodeSize;
		if(planeBisects == null){
			return;
		}
		List<Node> level = new ArrayList<Node>();
		for(int i = 0; i < planeBisects.length; i++){
			double[] point = new double[_planes.length];
			// Wrap so that more bisects than planes cycle through the axes like the tree does.
			point[i % _planes.length] = planeBisects[i];
			if(i == 0){
				_root = new Node(point);
				level.add(_root);
			}
			else{
				// Give every node of the previous level two children, which form the next level.
				List<Node> next = new ArrayList<Node>(level.size() * 2);
				for(Node n : level){
					n.Left = new Node(point);
					n.Right = new Node(point);
					next.add(n.Left);
					next.add(n.Right);
				}
				level = next;
			}
		}
	}
	
	/**
	 * Inserts a scan into the leaf that covers it, creating or splitting leaves as needed.
	 *
	 * @param s the scan to insert
	 */
	public void addDataNode(Scan s){
		if(_root == null){
			// Empty tree: the first scan starts a single leaf at the root.
			_root = new Node();
			_root.IsDataNode = true;
			_root.add(s, 0);
			return;
		}
		Node cursor = _root;
		Node parent = null;
		boolean goLeft = false;
		int depth = 0;
		while(cursor != null && !cursor.IsDataNode){
			int axis = depth % _planes.length;
			parent = cursor;
			goLeft = s.getProperty(_planes[axis]) < cursor.Point[axis];
			cursor = goLeft ? cursor.Left : cursor.Right;
			depth++;
		}
		if(cursor != null){
			cursor.add(s, depth);
			return;
		}
		// Walked off an internal node that has no child on this side (e.g. the bottom
		// level of a pre-seeded tree): hang a new leaf there.
		Node leaf = new Node();
		leaf.IsDataNode = true;
		if(goLeft){
			parent.Left = leaf;
		}
		else{
			parent.Right = leaf;
		}
		leaf.add(s, depth);
	}	
	
	/**
	 * Finds the {@code size} stored scans nearest to {@code center}.
	 *
	 * @param center  the query scan
	 * @param size    how many neighbours to return
	 * @param weights per-property weights, indexed by property index (see class doc)
	 * @return an array of length {@code size}, nearest first; trailing entries are
	 *         {@code null} when the tree holds fewer than {@code size} scans
	 */
	public Scan[] getCluster(Scan center, int size, double[] weights){
		Scan[] bestNodes = new Scan[size];
		if(size == 0 || _root == null){
			return bestNodes;
		}
		double[] bestErrs = new double[size];
		Arrays.fill(bestErrs, Double.MAX_VALUE);
		recurseCluster(center, bestNodes, bestErrs, _root, 0, weights);
		return bestNodes;
	}
	
	/**
	 * Depth-first nearest-neighbour search below {@code n}, updating the sorted
	 * {@code bestNodes}/{@code bestErrs} arrays in place.
	 */
	private void recurseCluster(Scan center, Scan[] bestNodes, double[] bestErrs, Node n, int depth, double[] weights){
		if(n.IsDataNode){
			for(int i = 0; i < n.Count; i++){
				Scan s = n.Scans[i];
				offerCandidate(s, distance(center, s, weights), bestNodes, bestErrs);
			}
			return;
		}
		
		int axis = depth % _planes.length;
		int property = _planes[axis];
		double value = center.getProperty(property);
		boolean leftFirst = value < n.Point[axis];
		Node near = leftFirst ? n.Left : n.Right;
		Node far = leftFirst ? n.Right : n.Left;
		
		if(near != null){
			recurseCluster(center, bestNodes, bestErrs, near, depth + 1, weights);
		}
		if(far != null){
			// The far side can only contain a closer scan if the splitting plane itself is
			// closer than the worst match kept so far.
			double planeGap = (value - n.Point[axis]) * weights[property];
			if(planeGap * planeGap < bestErrs[bestErrs.length - 1]){
				recurseCluster(center, bestNodes, bestErrs, far, depth + 1, weights);
			}
		}
	}
	
	/** Weighted squared Euclidean distance between two scans over the tree's planes. */
	private double distance(Scan a, Scan b, double[] weights){
		double sum = 0;
		for(int property : _planes){
			double d = (a.getProperty(property) - b.getProperty(property)) * weights[property];
			sum += d * d;
		}
		return sum;
	}
	
	/**
	 * Inserts {@code s} into the ascending {@code bestErrs}/{@code bestNodes} arrays if it beats
	 * the current worst entry, dropping that worst entry.
	 */
	private static void offerCandidate(Scan s, double err, Scan[] bestNodes, double[] bestErrs){
		int j = bestErrs.length - 1;
		if(!(err < bestErrs[j])){
			return;
		}
		while(j > 0 && err < bestErrs[j - 1]){
			bestErrs[j] = bestErrs[j - 1];
			bestNodes[j] = bestNodes[j - 1];
			j--;
		}
		bestErrs[j] = err;
		bestNodes[j] = s;
	}
	
	/**
	 * A tree node. An internal node has {@code IsDataNode == false} and a split value in
	 * {@code Point}; a leaf has {@code IsDataNode == true} and holds {@code Count} scans in
	 * {@code Scans[0..Count)}, sorted ascending on the leaf's splitting property.
	 */
	class Node{
		public Node Left;
		public Node Right;
		public double[] Point;
		public boolean IsDataNode;
		public Scan[] Scans;
		public int Count = 0;
		
		public Node(){}
		
		public Node(double[] point){
			Point = point;
		}
		
		/**
		 * Adds a scan to this leaf. When the leaf is full it is turned into an internal node
		 * with two new leaves; if its scans cannot be separated on this axis it grows instead.
		 *
		 * @param s     the scan to add
		 * @param depth depth of this node in the tree, which selects the splitting axis
		 */
		public void add(Scan s, int depth){
			int axis = depth % _planes.length;
			int property = _planes[axis];
			
			if(Scans == null){
				Scans = new Scan[_nodeSize + 1];
			}
			if(Count >= _nodeSize && trySplit(s, depth, axis, property)){
				return;
			}
			insertSorted(s, property);
		}
		
		/**
		 * Tries to turn this full leaf into an internal node whose two new leaves share all of
		 * its scans plus {@code s}.
		 *
		 * @return {@code false}, leaving this node untouched, when every scan has the same value
		 *         on this axis so that no split could separate them
		 */
		private boolean trySplit(Scan s, int depth, int axis, int property){
			Scan[] all = Arrays.copyOf(Scans, Count + 1);
			all[Count] = s;
			
			// Scans[0..Count) is kept sorted on this property, so its middle entry
			// approximates the median of all the scans.
			double median = all[Count / 2].getProperty(property);
			double threshold = median;
			boolean anyBelow = false;
			boolean anyAbove = false;
			double nextAbove = median;
			for(Scan t : all){
				double v = t.getProperty(property);
				if(v < median){
					anyBelow = true;
				}
				else if(v > median && (!anyAbove || v < nextAbove)){
					anyAbove = true;
					nextAbove = v;
				}
			}
			if(!anyBelow){
				// Splitting on "< median" would send every scan right, and that leaf would
				// overflow again without end (e.g. with duplicate scans). Split just above the
				// smallest value instead, or give up when all values are equal.
				if(!anyAbove){
					return false;
				}
				threshold = nextAbove;
			}
			
			Left = new Node();
			Left.IsDataNode = true;
			Right = new Node();
			Right.IsDataNode = true;
			// Redistribute every scan, including the new one in the last slot.
			for(Scan t : all){
				if(t.getProperty(property) < threshold){
					Left.add(t, depth + 1);
				}
				else{
					Right.add(t, depth + 1);
				}
			}
			
			Point = new double[_planes.length];
			Point[axis] = threshold;
			Count = 0;
			IsDataNode = false;
			Scans = null;
			return true;
		}
		
		/** Inserts {@code s} into {@code Scans} keeping it sorted on {@code property}, growing the array if needed. */
		private void insertSorted(Scan s, int property){
			if(Count >= Scans.length){
				Scans = Arrays.copyOf(Scans, Scans.length * 2 + 1);
			}
			double value = s.getProperty(property);
			int i = Count;
			while(i > 0 && value < Scans[i - 1].getProperty(property)){
				Scans[i] = Scans[i - 1];
				i--;
			}
			Scans[i] = s;
			Count++;
		}
	}
}
