From 89e82439186c6dfa48a73bddc0908a494fbd4394 Mon Sep 17 00:00:00 2001 From: JSDurand Date: Sat, 11 Mar 2023 16:21:20 +0800 Subject: optimize the implementation Now the double for loop is eliminated and the time complexity is indeed O(n log(n)). Also added (micro)-benchmarks to roughly confirm the time complexity is as predicted. --- benches/bench_sm.rs | 103 ++++++++++++++++++ src/lib.rs | 294 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 16 +++ 3 files changed, 413 insertions(+) create mode 100644 benches/bench_sm.rs create mode 100644 src/lib.rs diff --git a/benches/bench_sm.rs b/benches/bench_sm.rs new file mode 100644 index 0000000..736397a --- /dev/null +++ b/benches/bench_sm.rs @@ -0,0 +1,103 @@ +//! This file benchmarks the `sm` function. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use rand::{ + distributions::{Distribution, Uniform}, + thread_rng, +}; + +use sort::sm; + +fn bench_10(c: &mut Criterion) { + let mut rng = thread_rng(); + + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(10).collect::>() + }; + + assert_eq!(input.len(), 10); + + c.bench_function("sm", |b| { + b.iter(|| { + let mut count = 0; + black_box(sm(input.as_slice(), &mut count)) + }) + }); +} + +fn bench_20(c: &mut Criterion) { + let mut rng = thread_rng(); + + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(20).collect::>() + }; + + assert_eq!(input.len(), 20); + + c.bench_function("sm", |b| { + b.iter(|| { + let mut count = 0; + black_box(sm(input.as_slice(), &mut count)) + }) + }); +} + +fn bench_30(c: &mut Criterion) { + let mut rng = thread_rng(); + + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(30).collect::>() + }; + + assert_eq!(input.len(), 30); + + c.bench_function("sm", |b| { + b.iter(|| { + let mut count = 0; + black_box(sm(input.as_slice(), &mut count)) + }) + }); +} + +fn bench_40(c: &mut Criterion) { + let mut rng = thread_rng(); + + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(40).collect::>() + }; + + assert_eq!(input.len(), 40); + + c.bench_function("sm", |b| { + b.iter(|| { + let mut count = 0; + black_box(sm(input.as_slice(), &mut count)) + }) + }); +} + +fn bench_50(c: &mut Criterion) { + let mut rng = thread_rng(); + + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(50).collect::>() + }; + + assert_eq!(input.len(), 50); + + c.bench_function("sm", |b| { + b.iter(|| { + let mut count = 0; + black_box(sm(input.as_slice(), &mut count)) + }) + }); +} + +criterion_group!(benches, bench_10, bench_20, bench_30, bench_40, bench_50); +criterion_main!(benches); diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c474398 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,294 @@ +//! This file implements an optimized version of the sorting algorithm +//! deduced in the assignment. + +use core::{ + borrow::BorrowMut, + cmp::{Ordering, PartialOrd}, +}; + +#[cfg(test)] +use rand::{ + distributions::{Distribution, Uniform}, + thread_rng, +}; + +use std::collections::HashSet; + +/// Adjacency set representation +type Graph = Vec>; + +fn default_graph(n: usize) -> Graph { + std::iter::repeat_with(Default::default).take(n).collect() +} + +fn add_edge(graph: &mut Graph, source: usize, target: usize) { + let len = graph.len(); + + if source >= len { + panic!("source = {source} >= len = {len}"); + } + + if target >= len { + panic!("target = {target} >= len = {len}"); + } + + if !graph.get(source).unwrap().contains(&target) { + graph.get_mut(source).unwrap().insert(target); + } + + if !graph.get(target).unwrap().contains(&source) { + graph.get_mut(target).unwrap().insert(source); + } +} + +fn sm1( + a: &[T], + global_indices: &[usize], + graph: &mut Graph, + count: &mut usize, +) -> usize { + if a.is_empty() { + panic!("empty vector!"); + } + + if global_indices.is_empty() { + panic!("empty input"); + } + + let n = global_indices.len(); + + assert!(n <= a.len()); + + let mut indices: Vec = (0..n).collect(); + + let mut upper_bound = n; + + while upper_bound > 1 { + for i in 0..(upper_bound.div_euclid(2)) { + *count += 1; + + let x = *indices.get(2 * i).unwrap(); + let y = *indices.get(2 * i + 1).unwrap(); + + let ix = *global_indices.get(x).unwrap(); + let iy = *global_indices.get(y).unwrap(); + + let ax = a.get(ix).unwrap(); + let ay = a.get(iy).unwrap(); + + add_edge(graph.borrow_mut(), ix, iy); + + match ax.partial_cmp(ay) { + Some(Ordering::Less) | Some(Ordering::Equal) => { + *indices.get_mut(i).unwrap() = x; + } + Some(_) => { + *indices.get_mut(i).unwrap() = y; + } + None => { + // We perform one more comparison to make sure + // something that is not comparable, like + // `f32::NaN`, is moved to the end. + *count += 1; + + if ax.partial_cmp(ax).is_some() { + *indices.get_mut(i).unwrap() = x; + } else { + *indices.get_mut(i).unwrap() = y; + } + } + } + } + + let offset = upper_bound.rem_euclid(2); + + if offset == 1 { + *indices.get_mut(upper_bound.div_euclid(2)).unwrap() = + *indices.get(upper_bound - 1).unwrap(); + } + + upper_bound = upper_bound.div_euclid(2) + offset; + } + + *global_indices.get(*indices.first().unwrap()).unwrap() +} + +pub fn sm(a: &[T], count: &mut usize) -> Vec { + let n = a.len(); + + let mut result: Vec = Vec::with_capacity(n); + + // A "hashset" with trivial hashing + let mut added: Vec = std::iter::repeat(false).take(n).collect(); + + let mut sub_indices: Vec> = Vec::with_capacity(n); + + sub_indices.push((0..n).collect()); + + let mut graph = default_graph(n); + + for i in 0..n { + let global_indices = sub_indices.last().unwrap().as_slice(); + + let arg_min = sm1(a, global_indices, &mut graph, count); + + assert!(arg_min < n); + + result.push(arg_min); + + if i + 1 == n { + // avoid unnecessary work + break; + } + + *added.get_mut(arg_min).unwrap() = true; + + let mut new_sub_indices_set: HashSet = HashSet::new(); + + let adjacency_set = graph.get(arg_min).unwrap().iter().copied(); + + for x in adjacency_set { + if !new_sub_indices_set.contains(&x) && added.get(x).copied() == Some(false) { + new_sub_indices_set.insert(x); + } + } + + sub_indices.push(new_sub_indices_set.into_iter().collect()); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sm1() { + // We test each possible position of the smallest element. + + for n in 1..=10 { + let sorted_input: Vec = (0..n).collect(); + + for m in 0..n { + let mut input = sorted_input.clone(); + + let mut count = 0; + + input.swap(0, m); + + let global_indices: Vec = (0..n).collect(); + + let mut graph: Graph = std::iter::repeat_with(Default::default).take(n).collect(); + + let h = sm1( + input.as_slice(), + global_indices.as_slice(), + &mut graph, + &mut count, + ); + + assert_eq!(count, n - 1); + + assert_eq!(h, m); + + // test the adjacency lists + + for k in 0..n { + let mut list: Vec = graph.get(k).unwrap().iter().copied().collect(); + + list.sort_unstable(); // not merge_sort, by the way, haha. + + let mut upper = n; + let mut answer: Vec = Vec::with_capacity(n.ilog2() as usize + 1); + let mut indices: Vec = (0..n).collect(); + + while upper > 1 { + for i in 0..(upper.div_euclid(2)) { + let x = *indices.get(2 * i).unwrap(); + let y = *indices.get(2 * i + 1).unwrap(); + + if x == k { + answer.push(y); + } else if y == k { + answer.push(x); + } + + let ax = match x { + 0 => m, + z if z == m => 0, + z => z, + }; + + let ay = match y { + 0 => m, + z if z == m => 0, + z => z, + }; + + match ax.partial_cmp(&ay) { + Some(Ordering::Less) | Some(Ordering::Equal) => { + *indices.get_mut(i).unwrap() = x; + } + Some(_) => { + *indices.get_mut(i).unwrap() = y; + } + None => { + if ax.partial_cmp(&ax).is_some() { + *indices.get_mut(i).unwrap() = x; + } else { + *indices.get_mut(i).unwrap() = y; + } + } + } + } + + let offset = upper.rem_euclid(2); + + if offset == 1 { + *indices.get_mut(upper.div_euclid(2)).unwrap() = + *indices.get(upper - 1).unwrap(); + } + + upper = upper.div_euclid(2) + offset; + } + + answer.sort_unstable(); + + if list != answer { + panic!("k = {k}, n = {n}"); + } + } + } + } + } + + #[test] + fn test_sm_() { + let mut rng = thread_rng(); + + for i in 1..=5 { + let n = 10 * i; + let input = { + let uniform = Uniform::new(-100.0f32, 100.0f32); + uniform.sample_iter(&mut rng).take(n).collect::>() + }; + + let mut count = 0; + let _sorted = sm(input.as_slice(), &mut count); + + // println!( + // "sorted = {:?}", + // sorted + // .into_iter() + // .map(|x| *input.get(x).unwrap()) + // .collect::>() + // ); + + println!( + "n = {n}, count = {count}, nlog(n) = {}", + n * ((n as f32).log2() as usize) + ); + } + } +} diff --git a/src/main.rs b/src/main.rs index 64709b1..f223aeb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,6 +55,8 @@ fn sm1(a: Vec, count: &mut usize) -> (usize, Graph) { *indices.get_mut(i).unwrap() = y; } None => { + *count += 1; + if a.get(x).unwrap().partial_cmp(a.get(x).unwrap()).is_some() { *indices.get_mut(i).unwrap() = x; } else { @@ -181,4 +183,18 @@ fn main() { .collect(); println!("sort indices = {sort:?}\nsort result: {sort_result:?}\ncount = {count}"); + + // Now use the optimized version + + count = 0; + + let sort = sort::sm(inputs.as_slice(), &mut count); + + let sort_result: Vec = sort + .iter() + .copied() + .map(|n| inputs.get(n).copied().unwrap()) + .collect(); + + println!("sort indices = {sort:?}\nsort result: {sort_result:?}\ncount = {count}"); } -- cgit v1.2.3-18-g5258