summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJSDurand <mmemmew@gmail.com>2023-03-11 16:21:20 +0800
committerJSDurand <mmemmew@gmail.com>2023-03-11 16:21:20 +0800
commit89e82439186c6dfa48a73bddc0908a494fbd4394 (patch)
tree595b798be4adc66d202850303b0041b4de7a2ac7
parent45394c7cc8f6191e88edd6280b308ae69f588376 (diff)
optimize the implementation
Now the double for loop is eliminated and the time complexity is indeed O(n log(n)). Also added (micro)-benchmarks to roughly confirm the time complexity is as predicted.
-rw-r--r--benches/bench_sm.rs103
-rw-r--r--src/lib.rs294
-rw-r--r--src/main.rs16
3 files changed, 413 insertions, 0 deletions
diff --git a/benches/bench_sm.rs b/benches/bench_sm.rs
new file mode 100644
index 0000000..736397a
--- /dev/null
+++ b/benches/bench_sm.rs
@@ -0,0 +1,103 @@
+//! This file benchmarks the `sm` function.
+
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+
+use rand::{
+ distributions::{Distribution, Uniform},
+ thread_rng,
+};
+
+use sort::sm;
+
+fn bench_10(c: &mut Criterion) {
+ let mut rng = thread_rng();
+
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(10).collect::<Vec<_>>()
+ };
+
+ assert_eq!(input.len(), 10);
+
+ c.bench_function("sm", |b| {
+ b.iter(|| {
+ let mut count = 0;
+ black_box(sm(input.as_slice(), &mut count))
+ })
+ });
+}
+
+fn bench_20(c: &mut Criterion) {
+ let mut rng = thread_rng();
+
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(20).collect::<Vec<_>>()
+ };
+
+ assert_eq!(input.len(), 20);
+
+ c.bench_function("sm", |b| {
+ b.iter(|| {
+ let mut count = 0;
+ black_box(sm(input.as_slice(), &mut count))
+ })
+ });
+}
+
+fn bench_30(c: &mut Criterion) {
+ let mut rng = thread_rng();
+
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(30).collect::<Vec<_>>()
+ };
+
+ assert_eq!(input.len(), 30);
+
+ c.bench_function("sm", |b| {
+ b.iter(|| {
+ let mut count = 0;
+ black_box(sm(input.as_slice(), &mut count))
+ })
+ });
+}
+
+fn bench_40(c: &mut Criterion) {
+ let mut rng = thread_rng();
+
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(40).collect::<Vec<_>>()
+ };
+
+ assert_eq!(input.len(), 40);
+
+ c.bench_function("sm", |b| {
+ b.iter(|| {
+ let mut count = 0;
+ black_box(sm(input.as_slice(), &mut count))
+ })
+ });
+}
+
+fn bench_50(c: &mut Criterion) {
+ let mut rng = thread_rng();
+
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(50).collect::<Vec<_>>()
+ };
+
+ assert_eq!(input.len(), 50);
+
+ c.bench_function("sm", |b| {
+ b.iter(|| {
+ let mut count = 0;
+ black_box(sm(input.as_slice(), &mut count))
+ })
+ });
+}
+
+criterion_group!(benches, bench_10, bench_20, bench_30, bench_40, bench_50);
+criterion_main!(benches);
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..c474398
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,294 @@
+//! This file implements an optimized version of the sorting algorithm
+//! deduced in the assignment.
+
+use core::{
+ borrow::BorrowMut,
+ cmp::{Ordering, PartialOrd},
+};
+
+#[cfg(test)]
+use rand::{
+ distributions::{Distribution, Uniform},
+ thread_rng,
+};
+
+use std::collections::HashSet;
+
+/// Adjacency set representation
+type Graph = Vec<HashSet<usize>>;
+
+fn default_graph(n: usize) -> Graph {
+ std::iter::repeat_with(Default::default).take(n).collect()
+}
+
+fn add_edge(graph: &mut Graph, source: usize, target: usize) {
+ let len = graph.len();
+
+ if source >= len {
+ panic!("source = {source} >= len = {len}");
+ }
+
+ if target >= len {
+ panic!("target = {target} >= len = {len}");
+ }
+
+ if !graph.get(source).unwrap().contains(&target) {
+ graph.get_mut(source).unwrap().insert(target);
+ }
+
+ if !graph.get(target).unwrap().contains(&source) {
+ graph.get_mut(target).unwrap().insert(source);
+ }
+}
+
+fn sm1<T: PartialOrd>(
+ a: &[T],
+ global_indices: &[usize],
+ graph: &mut Graph,
+ count: &mut usize,
+) -> usize {
+ if a.is_empty() {
+ panic!("empty vector!");
+ }
+
+ if global_indices.is_empty() {
+ panic!("empty input");
+ }
+
+ let n = global_indices.len();
+
+ assert!(n <= a.len());
+
+ let mut indices: Vec<usize> = (0..n).collect();
+
+ let mut upper_bound = n;
+
+ while upper_bound > 1 {
+ for i in 0..(upper_bound.div_euclid(2)) {
+ *count += 1;
+
+ let x = *indices.get(2 * i).unwrap();
+ let y = *indices.get(2 * i + 1).unwrap();
+
+ let ix = *global_indices.get(x).unwrap();
+ let iy = *global_indices.get(y).unwrap();
+
+ let ax = a.get(ix).unwrap();
+ let ay = a.get(iy).unwrap();
+
+ add_edge(graph.borrow_mut(), ix, iy);
+
+ match ax.partial_cmp(ay) {
+ Some(Ordering::Less) | Some(Ordering::Equal) => {
+ *indices.get_mut(i).unwrap() = x;
+ }
+ Some(_) => {
+ *indices.get_mut(i).unwrap() = y;
+ }
+ None => {
+ // We perform one more comparison to make sure
+ // something that is not comparable, like
+ // `f32::NaN`, is moved to the end.
+ *count += 1;
+
+ if ax.partial_cmp(ax).is_some() {
+ *indices.get_mut(i).unwrap() = x;
+ } else {
+ *indices.get_mut(i).unwrap() = y;
+ }
+ }
+ }
+ }
+
+ let offset = upper_bound.rem_euclid(2);
+
+ if offset == 1 {
+ *indices.get_mut(upper_bound.div_euclid(2)).unwrap() =
+ *indices.get(upper_bound - 1).unwrap();
+ }
+
+ upper_bound = upper_bound.div_euclid(2) + offset;
+ }
+
+ *global_indices.get(*indices.first().unwrap()).unwrap()
+}
+
+pub fn sm<T: PartialOrd>(a: &[T], count: &mut usize) -> Vec<usize> {
+ let n = a.len();
+
+ let mut result: Vec<usize> = Vec::with_capacity(n);
+
+ // A "hashset" with trivial hashing
+ let mut added: Vec<bool> = std::iter::repeat(false).take(n).collect();
+
+ let mut sub_indices: Vec<Vec<usize>> = Vec::with_capacity(n);
+
+ sub_indices.push((0..n).collect());
+
+ let mut graph = default_graph(n);
+
+ for i in 0..n {
+ let global_indices = sub_indices.last().unwrap().as_slice();
+
+ let arg_min = sm1(a, global_indices, &mut graph, count);
+
+ assert!(arg_min < n);
+
+ result.push(arg_min);
+
+ if i + 1 == n {
+ // avoid unnecessary work
+ break;
+ }
+
+ *added.get_mut(arg_min).unwrap() = true;
+
+ let mut new_sub_indices_set: HashSet<usize> = HashSet::new();
+
+ let adjacency_set = graph.get(arg_min).unwrap().iter().copied();
+
+ for x in adjacency_set {
+ if !new_sub_indices_set.contains(&x) && added.get(x).copied() == Some(false) {
+ new_sub_indices_set.insert(x);
+ }
+ }
+
+ sub_indices.push(new_sub_indices_set.into_iter().collect());
+ }
+
+ result
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_sm1() {
+ // We test each possible position of the smallest element.
+
+ for n in 1..=10 {
+ let sorted_input: Vec<usize> = (0..n).collect();
+
+ for m in 0..n {
+ let mut input = sorted_input.clone();
+
+ let mut count = 0;
+
+ input.swap(0, m);
+
+ let global_indices: Vec<usize> = (0..n).collect();
+
+ let mut graph: Graph = std::iter::repeat_with(Default::default).take(n).collect();
+
+ let h = sm1(
+ input.as_slice(),
+ global_indices.as_slice(),
+ &mut graph,
+ &mut count,
+ );
+
+ assert_eq!(count, n - 1);
+
+ assert_eq!(h, m);
+
+ // test the adjacency lists
+
+ for k in 0..n {
+ let mut list: Vec<usize> = graph.get(k).unwrap().iter().copied().collect();
+
+ list.sort_unstable(); // not merge_sort, by the way, haha.
+
+ let mut upper = n;
+ let mut answer: Vec<usize> = Vec::with_capacity(n.ilog2() as usize + 1);
+ let mut indices: Vec<usize> = (0..n).collect();
+
+ while upper > 1 {
+ for i in 0..(upper.div_euclid(2)) {
+ let x = *indices.get(2 * i).unwrap();
+ let y = *indices.get(2 * i + 1).unwrap();
+
+ if x == k {
+ answer.push(y);
+ } else if y == k {
+ answer.push(x);
+ }
+
+ let ax = match x {
+ 0 => m,
+ z if z == m => 0,
+ z => z,
+ };
+
+ let ay = match y {
+ 0 => m,
+ z if z == m => 0,
+ z => z,
+ };
+
+ match ax.partial_cmp(&ay) {
+ Some(Ordering::Less) | Some(Ordering::Equal) => {
+ *indices.get_mut(i).unwrap() = x;
+ }
+ Some(_) => {
+ *indices.get_mut(i).unwrap() = y;
+ }
+ None => {
+ if ax.partial_cmp(&ax).is_some() {
+ *indices.get_mut(i).unwrap() = x;
+ } else {
+ *indices.get_mut(i).unwrap() = y;
+ }
+ }
+ }
+ }
+
+ let offset = upper.rem_euclid(2);
+
+ if offset == 1 {
+ *indices.get_mut(upper.div_euclid(2)).unwrap() =
+ *indices.get(upper - 1).unwrap();
+ }
+
+ upper = upper.div_euclid(2) + offset;
+ }
+
+ answer.sort_unstable();
+
+ if list != answer {
+ panic!("k = {k}, n = {n}");
+ }
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn test_sm_() {
+ let mut rng = thread_rng();
+
+ for i in 1..=5 {
+ let n = 10 * i;
+ let input = {
+ let uniform = Uniform::new(-100.0f32, 100.0f32);
+ uniform.sample_iter(&mut rng).take(n).collect::<Vec<_>>()
+ };
+
+ let mut count = 0;
+ let _sorted = sm(input.as_slice(), &mut count);
+
+ // println!(
+ // "sorted = {:?}",
+ // sorted
+ // .into_iter()
+ // .map(|x| *input.get(x).unwrap())
+ // .collect::<Vec<_>>()
+ // );
+
+ println!(
+ "n = {n}, count = {count}, nlog(n) = {}",
+ n * ((n as f32).log2() as usize)
+ );
+ }
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 64709b1..f223aeb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -55,6 +55,8 @@ fn sm1<T: PartialOrd>(a: Vec<T>, count: &mut usize) -> (usize, Graph) {
*indices.get_mut(i).unwrap() = y;
}
None => {
+ *count += 1;
+
if a.get(x).unwrap().partial_cmp(a.get(x).unwrap()).is_some() {
*indices.get_mut(i).unwrap() = x;
} else {
@@ -181,4 +183,18 @@ fn main() {
.collect();
println!("sort indices = {sort:?}\nsort result: {sort_result:?}\ncount = {count}");
+
+ // Now use the optimized version
+
+ count = 0;
+
+ let sort = sort::sm(inputs.as_slice(), &mut count);
+
+ let sort_result: Vec<f32> = sort
+ .iter()
+ .copied()
+ .map(|n| inputs.get(n).copied().unwrap())
+ .collect();
+
+ println!("sort indices = {sort:?}\nsort result: {sort_result:?}\ncount = {count}");
}