summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-03-14 23:11:09 +0800
committerJSDurand <mmemmew@gmail.com>2023-03-14 23:11:09 +0800
commitdca341dbb54737e8cc14a41088e23d23f7eb385e (patch)
treebaa16096f3afd64a707591ef863fb96d3855d3a5
parent89e82439186c6dfa48a73bddc0908a494fbd4394 (diff)
Implement the correct solution
Now the correct solution, as suggested by the professor, is implemented as the function `smi`. :D
-rw-r--r--src/incremental.rs318
-rw-r--r--src/lib.rs27
2 files changed, 336 insertions, 9 deletions
diff --git a/src/incremental.rs b/src/incremental.rs
new file mode 100644
index 0000000..361da7c
--- /dev/null
+++ b/src/incremental.rs
@@ -0,0 +1,318 @@
+//! This file implements the suggested solution from the professor,
+//! incrementally updating the existing graph, instead of creating new
+//! graphs.
+
+use super::*;
+
+/// One graph per round. A node has at most one partner in a round, so a
+/// graph is a plain list: entry `i` holds the node paired with `i`, or a
+/// value `>= len` (an out-of-bounds index) when `i` has no partner.
+type Graph = Vec<usize>;
+
+/// Builds a graph on `n` nodes that has no edges yet.
+///
+/// Every entry is set to `n`, which is one past the last valid index
+/// and therefore marks the node as not connected to anything.
+fn default_graph(n: usize) -> Graph {
+    vec![n; n]
+}
+
+/// Records that `source` and `target` are partners in `graph`.
+///
+/// The edge is stored in both directions: the entry of `source`
+/// becomes `target` and the entry of `target` becomes `source`.
+/// Whatever either entry held before is overwritten.
+///
+/// # Panics
+///
+/// Panics if `source` or `target` is not a valid index of `graph`.
+fn add_edge(graph: &mut Graph, source: usize, target: usize) {
+    let len = graph.len();
+
+    assert!(source < len, "source = {source} out of bound: {len}");
+    assert!(target < len, "target = {target} out of bound: {len}");
+
+    graph[source] = target;
+    graph[target] = source;
+}
+
+// NOTE: This seems to be unnecessary now.
+/// Returns the partner of `node` in the graph of round `round`.
+///
+/// To be more precise, if `node` is paired with a node `n` in round
+/// `round`, then return `Some(n)`.  If the stored entry is out of
+/// bounds, which is how a node without partner is marked, return
+/// `None`.
+///
+/// # Panics
+///
+/// Panics if either `round` or `node` is out of bounds.
+#[allow(dead_code)]
+fn assoc(graphs: &[Graph], round: usize, node: usize) -> Option<usize> {
+    let Some(graph) = graphs.get(round) else {
+        panic!("round = {round} out of bound = {}", graphs.len());
+    };
+
+    let Some(partner) = graph.get(node).copied() else {
+        panic!("node = {node} out of bound = {}", graph.len());
+    };
+
+    // An entry that is not a valid index is the "no partner" marker.
+    (partner < graph.len()).then_some(partner)
+}
+
+/// Initial **MIN** algorithm.
+///
+/// Runs a knockout tournament over the indices of `a`.  In every round
+/// the surviving indices are compared in adjacent pairs and the index
+/// of the smaller element survives (the first one on a tie); with an
+/// odd number of survivors the last one advances without a comparison.
+/// Each pair is recorded as an edge in that round's graph, and the
+/// graphs are pushed to `graphs` in round order after `graphs` has been
+/// cleared.
+///
+/// `count` is increased by one for every pair, and by one more when the
+/// two elements turn out to be incomparable.
+///
+/// Returns the index that survives the last round.
+///
+/// # Panics
+///
+/// Panics if `a` is empty or if reserving space in `graphs` fails.
+fn smi_init<T: PartialOrd>(a: &[T], graphs: &mut Vec<Graph>, count: &mut usize) -> usize {
+    if a.is_empty() {
+        panic!("invalid empty input");
+    }
+
+    let n = a.len();
+
+    // Number of rounds, that is ceil(log2(n)).  This is computed with
+    // integers only: going through `f32` rounds large `n` and can
+    // yield one round too few, which would trip the assertion below.
+    let logn_ceil = if n <= 1 {
+        0
+    } else {
+        (n - 1).ilog2() as usize + 1
+    };
+
+    // The first `upper` entries are the indices still in the race.
+    let mut indices: Vec<usize> = (0..n).collect();
+    let mut upper = n;
+
+    graphs.clear();
+    graphs
+        .try_reserve(logn_ceil)
+        .unwrap_or_else(|e| panic!("reserving memory fails: {e}"));
+
+    while upper > 1 {
+        let mut graph = default_graph(n);
+
+        let div = upper / 2;
+
+        for i in 0..div {
+            *count += 1;
+
+            let x = indices[2 * i];
+            let y = indices[2 * i + 1];
+
+            let ax = &a[x];
+            let ay = &a[y];
+
+            add_edge(&mut graph, x, y);
+
+            // Writing slot `i` is safe here: slots `2 * i` and
+            // `2 * i + 1` have already been read.
+            indices[i] = match ax.partial_cmp(ay) {
+                Some(Ordering::Less | Ordering::Equal) => x,
+                Some(Ordering::Greater) => y,
+                None => {
+                    *count += 1;
+
+                    // Incomparable pair: keep `x` if its element is at
+                    // least comparable with itself, otherwise `y`.
+                    if ax.partial_cmp(ax).is_some() {
+                        x
+                    } else {
+                        y
+                    }
+                }
+            };
+        }
+
+        let offset = upper % 2;
+
+        if offset == 1 {
+            // The unpaired last survivor advances unchallenged.
+            indices[div] = indices[upper - 1];
+        }
+
+        upper = div + offset;
+
+        graphs.push(graph);
+    }
+
+    assert_eq!(graphs.len(), logn_ceil);
+
+    indices[0]
+}
+
+/// Finds the next index to output after `previous_min_index`.
+///
+/// The graphs are visited in round order.  In each graph where
+/// `previous_min_index` has a partner, that partner is matched against
+/// the current candidate (or simply becomes the candidate if there is
+/// none yet), and the pairing is recorded as an edge in that graph.
+/// A side that is already in `added_indices` loses without looking at
+/// the values; if both sides are already added, the candidate is
+/// dropped.  Otherwise the candidate is kept when its element is less
+/// than or equal to the partner's, and replaced by the partner in every
+/// other case, including incomparable elements.
+///
+/// `count` is increased by one for every pairing, including those that
+/// are decided through `added_indices` alone.
+///
+/// Returns the final candidate.  If no candidate is left, the smallest
+/// index that is not in `added_indices` is returned instead.
+///
+/// # Panics
+///
+/// Panics if `previous_min_index` is not a valid index of `a` or of
+/// one of the graphs, if `graphs` is empty, or if no candidate is left
+/// and every index of `a` is already in `added_indices`.
+fn smi_step<T: PartialOrd>(
+    a: &[T],
+    previous_min_index: usize,
+    added_indices: &HashSet<usize>,
+    graphs: &mut Vec<Graph>,
+    count: &mut usize,
+) -> usize {
+    let n = a.len();
+
+    if previous_min_index >= n {
+        panic!("invalid min index: {previous_min_index} with n = {n}");
+    }
+
+    if graphs.is_empty() {
+        panic!("invalid empty round");
+    }
+
+    let mut to_compare: Option<usize> = None;
+
+    for graph in graphs.iter_mut() {
+        let partner = match graph.get(previous_min_index).copied() {
+            // No partner in this round, nothing to do here.
+            Some(p) if p >= graph.len() => continue,
+            Some(p) => p,
+            None => panic!(
+                "prev = {previous_min_index}, graph length = {}",
+                graph.len()
+            ),
+        };
+
+        let Some(cand) = to_compare else {
+            to_compare = Some(partner);
+            continue;
+        };
+
+        *count += 1;
+
+        add_edge(graph, cand, partner);
+
+        let acand = &a[cand];
+        let a_partner = &a[partner];
+
+        let cand_already_added = added_indices.contains(&cand);
+        let partner_already_added = added_indices.contains(&partner);
+
+        to_compare = match (cand_already_added, partner_already_added) {
+            (true, true) => None,
+            (true, false) => Some(partner),
+            (false, true) => Some(cand),
+            (false, false) => match acand.partial_cmp(a_partner) {
+                Some(Ordering::Less | Ordering::Equal) => Some(cand),
+                // Greater or incomparable: the partner takes over.
+                _ => Some(partner),
+            },
+        };
+    }
+
+    to_compare.unwrap_or_else(|| {
+        // No candidate survived, fall back to the first index that
+        // has not been output yet.
+        (0..n)
+            .find(|i| !added_indices.contains(i))
+            .expect("every index has already been added")
+    })
+}
+
+/// Produces the indices of `a` one at a time, reusing the comparison
+/// graphs of the previous steps instead of building new ones.
+///
+/// The first index comes from `smi_init`; every later index comes from
+/// `smi_step`, which is given the index produced just before it and the
+/// set of all indices produced so far.  The result has exactly
+/// `a.len()` entries, and the unit tests of this module check that
+/// reading `a` through the result gives `a` in ascending order.
+///
+/// `count` is passed on to both helpers, which add the number of
+/// comparisons they account for.
+///
+/// Returns an empty vector if `a` is empty.
+pub fn smi<T: PartialOrd>(a: &[T], count: &mut usize) -> Vec<usize> {
+    if a.is_empty() {
+        return Vec::new();
+    }
+
+    let n = a.len();
+
+    let mut graphs: Vec<Graph> = Vec::with_capacity((n.ilog2() as usize) + 1);
+
+    let mut result = Vec::with_capacity(n);
+    let mut added_indices = HashSet::with_capacity(n);
+
+    let mut current = smi_init(a, &mut graphs, count);
+
+    result.push(current);
+    added_indices.insert(current);
+
+    while result.len() < n {
+        current = smi_step(a, current, &added_indices, &mut graphs, count);
+
+        result.push(current);
+        added_indices.insert(current);
+    }
+
+    result
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    use rand::{
+        distributions::{Distribution, Uniform},
+        thread_rng,
+    };
+
+    /// The minimum of a small fixed input is found by `smi_init`, and
+    /// one call of `smi_step` then yields the expected next index.
+    #[test]
+    fn test_smi_init_and_one_step() {
+        let input = vec![1, 2, 3, 0, 9, 2, 3];
+
+        let mut count = 0;
+        let mut graphs = Vec::with_capacity((input.len().ilog2() as usize) + 1);
+
+        let min = smi_init(&input, &mut graphs, &mut count);
+
+        let mut added_indices = HashSet::with_capacity(input.len());
+        added_indices.insert(min);
+
+        assert_eq!(min, 3);
+
+        let second_min = smi_step(&input, min, &added_indices, &mut graphs, &mut count);
+
+        assert_eq!(second_min, 0);
+    }
+
+    /// Random inputs of length 20, 40, 80, 160 and 320 are sorted by
+    /// `smi` exactly like the standard library sorts them.
+    #[test]
+    fn test_smi_itself() {
+        for i in 1..=5 {
+            let len: usize = 10 * (1 << i);
+
+            let mut rng = thread_rng();
+
+            let input: Vec<i32> = Uniform::new(-100i32, 100i32)
+                .sample_iter(&mut rng)
+                .take(len)
+                .collect();
+
+            let n = input.len();
+
+            println!("input length = {n}");
+
+            let nlogn = n * ((n as f32).log2().ceil() as usize);
+
+            let mut count = 0;
+
+            let sort_result = smi(&input, &mut count);
+
+            println!("count = {count}, nlog(n) = {nlogn}");
+
+            // Turn the returned indices back into the values they
+            // point at.
+            let sort_result_numbers: Vec<_> = sort_result.iter().map(|&x| input[x]).collect();
+
+            let mut answer = input.clone();
+            answer.sort_unstable();
+
+            assert_eq!(answer, sort_result_numbers);
+
+            println!("correctly sorted");
+            println!();
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index c474398..ff309d1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,6 +14,10 @@ use rand::{
use std::collections::HashSet;
+pub mod incremental;
+
+pub use incremental::smi;
+
/// Adjacency set representation
type Graph = Vec<HashSet<usize>>;
@@ -275,19 +279,24 @@ mod tests {
};
let mut count = 0;
- let _sorted = sm(input.as_slice(), &mut count);
- // println!(
- // "sorted = {:?}",
- // sorted
- // .into_iter()
- // .map(|x| *input.get(x).unwrap())
- // .collect::<Vec<_>>()
- // );
+ let sorted = sm(input.as_slice(), &mut count);
+
+ let sorted: Vec<_> = sorted.into_iter().map(|x| *input.get(x).unwrap()).collect();
+
+ let mut really_sorted = input;
+
+ // floats are not totally ordered, so there is no default
+ // method to sort a vector of floats.
+ really_sorted.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap());
+
+ assert_eq!(sorted, really_sorted);
+
+ // println!("sorted = {:?}", sorted,);
println!(
"n = {n}, count = {count}, nlog(n) = {}",
- n * ((n as f32).log2() as usize)
+ n * ((n as f32).log2().ceil() as usize)
);
}
}