Created December 8, 2022 21:49
Save mbhall88/80cd054410f960cea0c451b8b0edae71 to your computer and use it in GitHub Desktop.
Argsort and sort by indices in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rayon::prelude::*;
/// Returns the indices that would sort a vector in parallel | |
fn par_argsort<T: Ord + Sync>(data: &[T]) -> Vec<usize> { | |
let mut indices = (0..data.len()).collect::<Vec<_>>(); | |
indices.par_sort_unstable_by_key(|&i| &data[i]); | |
indices | |
} | |
/// Sort the given vector by the given indices in parallel | |
fn sort_by_indices<T: Ord + Sync>(v: &mut Vec<T>, indices: &[usize]) { | |
// Create a vector of indices to be sorted | |
let mut indices_to_sort: Vec<_> = (0..v.len()).collect(); | |
// Use `par_sort_unstable_by_key` to sort the indices in parallel | |
indices_to_sort.par_sort_unstable_by_key(|&i| indices[i]); | |
// Use `par_sort_unstable_by_key` to sort the elements of `v` by the sorted indices | |
v.par_sort_unstable_by_key(|x| indices_to_sort.iter().position(|&i| i == x).unwrap()); | |
} | |
/// Returns the permutation of indices that would sort `data` in ascending order.
///
/// The sort is stable: indices of equal elements keep their original relative order.
fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..data.len()).collect();
    // `sort_by` is a stable sort, so ties stay in ascending index order.
    order.sort_by(|&a, &b| data[a].cmp(&data[b]));
    order
}
/// Reorders `v` in place so that position `k` receives the element previously at
/// `v[indices[k]]` (a "gather" by `indices`).
///
/// Passing the output of `argsort(&v)` therefore sorts `v`. Runs in O(n) and does not
/// require `T` to be `Clone` or `Ord`.
///
/// # Panics
///
/// Panics if `indices.len() != v.len()`, or if `indices` is not a permutation of
/// `0..v.len()` (an index is out of range or repeated).
fn sort_by_indices<T>(v: &mut Vec<T>, indices: &[usize]) {
    assert_eq!(indices.len(), v.len(), "`indices` must have the same length as `v`");
    // Move every element into a slot so each one can be taken out exactly once.
    let mut slots: Vec<Option<T>> = v.drain(..).map(Some).collect();
    v.extend(indices.iter().map(|&i| {
        slots
            .get_mut(i)
            .and_then(Option::take)
            .expect("`indices` must be a permutation of 0..v.len()")
    }));
}
#[test]
fn test_argsort() {
    let input = [3, 1, 2];
    assert_eq!(argsort(&input), vec![1, 2, 0]);
}
#[test]
fn test_par_argsort() {
    let input = [3, 1, 2];
    assert_eq!(par_argsort(&input), vec![1, 2, 0]);
}
#[test]
fn test_sort_by_indices() {
    let mut v = vec!["a", "b", "c", "d"];
    let i = vec![0, 3, 2, 1];
    sort_by_indices(&mut v, &i);
    assert_eq!(v, &["a", "d", "c", "b"]);

    // `[0, 3, 2, 1]` is its own inverse, so it cannot tell a gather from a scatter.
    // A 3-cycle can: position k must receive the element that was at `indices[k]`.
    let mut v = vec!["a", "b", "c"];
    sort_by_indices(&mut v, &[2, 0, 1]);
    assert_eq!(v, &["c", "a", "b"]);

    // Applying a vector's argsort sorts it.
    let mut nums = vec![30, 10, 20];
    let order = argsort(&nums);
    sort_by_indices(&mut nums, &order);
    assert_eq!(nums, vec![10, 20, 30]);
}
#[test]
fn test_par_sort_by_indices() {
    let mut v = vec!["a", "b", "c", "d"];
    let i = vec![0, 3, 2, 1];
    par_sort_by_indices(&mut v, &i);
    assert_eq!(v, &["a", "d", "c", "b"]);

    // `[0, 3, 2, 1]` is its own inverse, so it cannot tell a gather from a scatter.
    // A 3-cycle can: position k must receive the element that was at `indices[k]`.
    let mut v = vec!["a", "b", "c"];
    par_sort_by_indices(&mut v, &[2, 0, 1]);
    assert_eq!(v, &["c", "a", "b"]);

    // Applying a vector's argsort sorts it.
    let mut nums = vec![30, 10, 20];
    let order = par_argsort(&nums);
    par_sort_by_indices(&mut nums, &order);
    assert_eq!(nums, vec![10, 20, 30]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.