cheshirekow  v0.1.0
tree.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2012 Josh Bialkowski (jbialk@mit.edu)
3  *
4  * This file is part of mpblocks.
5  *
6  * mpblocks is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * mpblocks is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with mpblocks. If not, see <http://www.gnu.org/licenses/>.
18  */
27 #ifndef MPBLOCKS_BTPS_TREE_H_
28 #define MPBLOCKS_BTPS_TREE_H_
29 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <list>
#include <map>
#include <utility>
34 
35 //#include <iostream>
36 //#include <boost/format.hpp>
37 
38 namespace mpblocks {
39 namespace btps {
40 
/// Binary tree of partial sums, used to sample a node from a discrete
/// distribution with arbitrary (unnormalized) weights.
///
/// The tree never allocates or frees nodes; it only links nodes handed to it
/// by the caller. All node fields are reached through `Traits::NodeOps`, which
/// must be default-constructible and provide, for a `Traits::NodeRef` N:
///   - `Parent(N)`, `LeftChild(N)`, `RightChild(N)` returning `NodeRef&`
///   - `Count(N)` and `CumulativeWeight(N)` returning assignable references
///   - `Weight(N)` returning the (readable) weight of the node itself
///
/// Each node stores the number of nodes in its subtree (`Count`) and the sum
/// of the weights in its subtree (`CumulativeWeight`). Insertion always
/// descends into the child with the smaller subtree, and removal swaps the
/// removed node with a leaf, so no rotations are ever performed.
template <class Traits>
class Tree {
 public:
  typedef typename Traits::NodeRef NodeRef;
  typedef typename Traits::NodeOps NodeOps;

 private:
  // NOTE: the data members are declared ahead of the accessors below because
  // the accessors name `ops_` / `root_` in their trailing return types.
  NodeRef nil_;   ///< sentinel node standing in for "no node"
  NodeRef root_;  ///< root of the tree, equal to nil_ when the tree is empty
  NodeOps ops_;   ///< maps a NodeRef to its fields

  /// parent link of N
  NodeRef& parent(NodeRef N) { return ops_.Parent(N); }
  /// left child link of N
  NodeRef& left(NodeRef N) { return ops_.LeftChild(N); }
  /// right child link of N
  NodeRef& right(NodeRef N) { return ops_.RightChild(N); }

  /// number of nodes in the subtree rooted at N (including N)
  auto count(NodeRef N) -> decltype(ops_.Count(N)) { return ops_.Count(N); }

  /// sum of the weights of all nodes in the subtree rooted at N
  auto cum(NodeRef N) -> decltype(ops_.CumulativeWeight(N)) {
    return ops_.CumulativeWeight(N);
  }

  /// weight of N alone
  auto weight(NodeRef N) -> decltype(ops_.Weight(N)) { return ops_.Weight(N); }

 public:
  /// Initialize an empty tree using Nil as the sentinel object.
  ///
  /// The sentinel's count and cumulative weight are zeroed and all of its
  /// links are pointed back at itself, so it can be read as an empty subtree.
  Tree(NodeRef Nil) : nil_(Nil), root_(Nil) {
    cum(nil_) = 0;
    count(nil_) = 0;
    parent(nil_) = nil_;
    left(nil_) = nil_;
    right(nil_) = nil_;
  }

  /// return the root node in the tree (the sentinel if the tree is empty)
  NodeRef root() { return root_; }

  /// return the total weight of all nodes in the tree
  auto sum() -> decltype(ops_.CumulativeWeight(root_)) {
    return ops_.CumulativeWeight(root_);
  }

  /// clear out the tree (does not free nodes, and does not touch their links)
  void clear() { root_ = nil_; }

  /// return the count of nodes in the tree
  std::size_t size() { return count(root_); }

 private:
  /// Recompute the count and cumulative weight of x from its own weight and
  /// the stored count / cumulative weight of its two children.
  void update(NodeRef x) {
    cum(x) = weight(x) + cum(left(x)) + cum(right(x));
    count(x) = 1 + count(left(x)) + count(right(x));
  }

  /// Recursively insert x in the subtree rooted at p, then refresh every node
  /// on the path from the attachment point back up to p.
  void insert(NodeRef p, NodeRef x) {
    assert(p != nil_);

    if (left(p) == nil_) {
      // a free child slot: attach here
      left(p) = x;
      parent(x) = p;
    } else if (right(p) == nil_) {
      right(p) = x;
      parent(x) = p;
    } else if (count(left(p)) < count(right(p))) {
      // both slots taken: descend into the child with the smaller progeny
      insert(left(p), x);
    } else {
      insert(right(p), x);
    }

    update(p);
  }

  /// Walk the tree along the parent path from p to the root and update all
  /// nodes on that path. Does nothing when p is the sentinel.
  void updateAncestry(NodeRef p) {
    for (NodeRef n = p; n != nil_; n = parent(n)) {
      update(n);
      if (n == root_) break;
    }
  }

  /// Find the deepest leaf in a subtree, by always descending into the child
  /// holding more nodes (the right child on a tie).
  NodeRef deepestLeaf(NodeRef p) {
    // the sentinel has count 0 and links to itself, so descending from it
    // would never terminate
    assert(p != nil_);

    if (count(p) == 1) return p;
    if (count(left(p)) > count(right(p)))
      return deepestLeaf(left(p));
    else
      return deepestLeaf(right(p));
  }

  /// true if x is a leaf node (i.e. both children are Nil)
  bool isLeaf(NodeRef x) { return (left(x) == nil_ && right(x) == nil_); }

  /// Return the address of the child link in n's parent that refers to n, or
  /// a null pointer when n has no parent.
  NodeRef* parentSlot(NodeRef n) {
    NodeRef p = parent(n);
    if (p == nil_) return nullptr;
    if (left(p) == n) return &left(p);
    assert(right(p) == n);
    return &right(p);
  }

  /// Exchange the positions of a and b in the tree: swaps everything but the
  /// weight and cumulative weight (i.e. parent, children and count).
  ///
  /// The caller is responsible for refreshing cumulative weights afterwards.
  void swap(NodeRef a, NodeRef b) {
    // Resolve both parent slots before writing either one. Rewriting a's
    // slot first and then searching for b's slot would find the slot that
    // was just rewritten whenever a and b share a parent, undoing the
    // exchange.
    NodeRef* slotOfA = parentSlot(a);
    NodeRef* slotOfB = parentSlot(b);
    if (slotOfA) *slotOfA = b;
    if (slotOfB) *slotOfB = a;

    std::swap(parent(a), parent(b));
    std::swap(left(a), left(b));
    std::swap(right(a), right(b));
    std::swap(count(a), count(b));

    // the children that changed hands must point at their new parent; the
    // sentinel's parent link is deliberately never written
    if (left(a) != nil_) parent(left(a)) = a;
    if (right(a) != nil_) parent(right(a)) = a;
    if (left(b) != nil_) parent(left(b)) = b;
    if (right(b) != nil_) parent(right(b)) = b;
  }

  /// Remove a leaf node from the tree and refresh its former ancestors.
  void removeLeaf(NodeRef x) {
    assert(isLeaf(x));
    NodeRef p = parent(x);

    if (p == nil_) {
      // x has no parent, so it must be the root; since it is also a leaf,
      // removing it empties the tree
      assert(x == root_);
      root_ = nil_;
    } else {
      // detach x from its parent's list of children
      if (left(p) == x)
        left(p) = nil_;
      else
        right(p) = nil_;
      parent(x) = nil_;

      // update all the counts up to the root
      updateAncestry(p);
    }
  }

  /// Given a value in [0, cum(x)], find the node of the subtree rooted at x
  /// whose weight interval contains it. Intervals are laid out in order:
  /// left subtree, then x itself, then right subtree.
  ///
  /// A value below zero resolves to the leftmost node reached, and a value at
  /// or beyond the end of the range resolves to the rightmost node reached.
  template <typename T>
  NodeRef findInterval(NodeRef x, T val) {
    assert(x != nil_);

    auto wLeft = cum(left(x));
    auto wMiddle = wLeft + weight(x);

    // the value falls in the left part of the split: recurse on the left
    // subtree (only if there is one, so a negative value cannot send the
    // search into the sentinel)
    if (left(x) != nil_ && val < wLeft) return findInterval(left(x), val);

    // the value falls in x's own interval, or there is nowhere further right
    // to look
    else if (val < wMiddle || right(x) == nil_)
      return x;

    // otherwise recurse on the right subtree; subtrees do not know their
    // offset, so remove the pruned cum(left(x)) + weight(x) from the value
    else
      return findInterval(right(x), val - wMiddle);
  }

  /// Return the larger of the squared error of cum(x) and the squared errors
  /// found in its children's subtrees; 0 for the sentinel.
  double validateNode(NodeRef x) {
    if (x == nil_) return 0;

    double err = cum(x) - (cum(left(x)) + weight(x) + cum(right(x)));
    err *= err;

    err = std::max(err, validateNode(left(x)));
    err = std::max(err, validateNode(right(x)));
    return err;
  }

 public:
  /// Insert a single node into the tree. weight(x) must be set by the caller;
  /// all other fields are initialized by the tree.
  void insert(NodeRef x) {
    // inserting the sentinel would overwrite its zero count / self links
    assert(x != nil_);

    parent(x) = nil_;
    left(x) = nil_;
    right(x) = nil_;
    count(x) = 1;
    cum(x) = weight(x);

    if (root_ == nil_)
      root_ = x;
    else
      insert(root_, x);
  }

  /// Remove a node from the tree. x must currently be in the tree.
  void remove(NodeRef x) {
    assert(x != nil_);

    // if x is a leaf, then simply remove it
    if (isLeaf(x)) {
      removeLeaf(x);
      return;
    }

    // x is not a leaf: swap it out with a leaf so that it can be removed
    // without rebalancing the tree
    NodeRef r = deepestLeaf(root_);
    swap(x, r);

    // if x was the root then r is the root now
    if (root_ == x) root_ = r;

    // r sits where x used to be: refresh that path up to the root
    updateAncestry(r);

    // x sits where r used to be, as a leaf; removing it refreshes that path
    removeLeaf(x);
  }

  /// Return the largest difference between cum(x) and the value recomputed
  /// from x's weight and children, over all nodes in the tree.
  double validateTree() {
    if (root_ == nil_)
      return 0;
    else
      return std::sqrt(validateNode(root_));
  }

  /// Fill depthMap with a list of nodes at each depth (root at depth 0).
  ///
  /// Nodes are appended to whatever depthMap already holds. For a non-empty
  /// tree an empty list is also left behind at (deepest depth + 1).
  void generateDepthProfile(std::map<int, std::list<NodeRef> >& depthMap) {
    if (root_ == nil_) return;

    depthMap[0].push_back(root_);
    for (int i = 1; true; i++) {
      if (depthMap[i - 1].size() < 1) break;
      for (NodeRef node : depthMap[i - 1]) {
        if (left(node) != nil_) depthMap[i].push_back(left(node));
        if (right(node) != nil_) depthMap[i].push_back(right(node));
      }
    }
  }

  /// Given val in [0,1], sample a node from the weighted distribution.
  /// Returns the sentinel when the tree is empty.
  template <typename T>
  NodeRef findInterval(T val) {
    if (root_ == nil_) return root_;

    // scale the unit-interval sample up to the total weight of the tree
    return findInterval(root_, val * cum(root_));
  }
};
351 
352 } //< namespace btps
353 } //< namespace mpblocks
354 
355 #endif // MPBLOCKS_BTPS_TREE_H_
NodeOps ops_
Definition: tree.h:52
NodeRef findInterval(NodeRef x, T val)
given $ val \in [0,1] $, sample a node from the weighted distribution over the subtree rooted at x ...
Definition: tree.h:242
NodeRef root()
return the root node in the tree
Definition: tree.h:94
void insert(NodeRef p, NodeRef x)
recursively insert x in the subtree rooted at p
Definition: tree.h:117
void removeLeaf(NodeRef x)
remove a leaf node from the tree
Definition: tree.h:213
auto sum() -> decltype(ops_.CumulativeWeight(root_))
return the total weight of all nodes in the tree
Definition: tree.h:97
void swap(NodeRef a, NodeRef b)
swaps everything but the weight and cumulative weight
Definition: tree.h:167
auto weight(NodeRef N) -> decltype(ops_.Weight(N))
Definition: tree.h:66
double validateTree()
return the largest difference between cum(x) and the computed value over all nodes in the tree ...
Definition: tree.h:321
size_t size()
return the count of nodes in the tree
Definition: tree.h:105
implements a binary tree of partial sums for sampling from discrete distributions with arbitrary weig...
Definition: tree.h:44
auto count(NodeRef N) -> decltype(ops_.Count(N))
Definition: tree.h:58
Traits::NodeOps NodeOps
Definition: tree.h:47
bool isLeaf(NodeRef x)
true if x is a leaf node (i.e. both children are Nil)
Definition: tree.h:164
void update(NodeRef x)
update a node's count and cumulative weight by summing the count/weight of itself with the cumulative...
Definition: tree.h:111
NodeRef findInterval(T val)
given $ val \in [0,1] $, sample a node from the weighted distribution
Definition: tree.h:345
NodeRef deepestLeaf(NodeRef p)
find the deepest leaf in a subtree
Definition: tree.h:155
NodeRef & right(NodeRef N)
Definition: tree.h:56
NodeRef nil_
Definition: tree.h:50
NodeRef root_
Definition: tree.h:51
void updateAncestry(NodeRef p)
walk the tree along the parent path from p to the root and update all nodes on that path ...
Definition: tree.h:142
void generateDepthProfile(std::map< int, std::list< NodeRef > > &depthMap)
fill depthMap with a list of nodes at each depth
Definition: tree.h:329
void insert(NodeRef x)
insert a single node into the tree, weight(x) must be set, all other fields are initialized by the tr...
Definition: tree.h:281
Traits::NodeRef NodeRef
Definition: tree.h:46
Tree(NodeRef Nil)
initialize a tree using Nil as the sentinel object
Definition: tree.h:85
NodeRef & left(NodeRef N)
Definition: tree.h:55
NodeRef & parent(NodeRef N)
Definition: tree.h:54
void clear()
clear out the tree (does not free nodes)
Definition: tree.h:102
auto cum(NodeRef N) -> decltype(ops_.CumulativeWeight(N))
Definition: tree.h:62
double validateNode(NodeRef x)
return the larger of the error of cum(x) or the error of its children
Definition: tree.h:267