adding rust
parent
915dbbefcb
commit
eece87ad77
@ -1,22 +1,103 @@
|
||||
use ndarray::Array;
|
||||
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2};
|
||||
use ndarray::prelude::*;
|
||||
|
||||
use ndarray::{Array, Array2, ArrayView2};
|
||||
use ndrustfft::{ndfft_r2c, ndifft_r2c, Complex, R2cFftHandler};
|
||||
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
|
||||
use pyo3::prelude::*;
|
||||
|
||||
fn triangular_taper(taper_size: usize, plateau: usize) -> Array<f32, _> {
|
||||
assert!(plateau < taper_size, "Plateau cannot be larger than size.");
|
||||
assert!(taper_size as f32 % 2.0 != 0., "Size has to be even");
|
||||
use rayon::prelude::*;
|
||||
|
||||
use std::cmp;
|
||||
|
||||
fn triangular_taper(window_size: usize, plateau: usize) -> Array2<f32> {
|
||||
assert!(plateau < window_size, "Plateau cannot be larger than size.");
|
||||
assert!(window_size as f32 % 2.0 == 0., "Size has to be even");
|
||||
|
||||
let ramp_size = (window_size - plateau) / 2;
|
||||
let ramp = Array::<f32, _>::linspace(0., 1., ramp_size);
|
||||
let mut taper = Array::<f32, _>::ones(window_size);
|
||||
|
||||
taper.slice_mut(s![..ramp_size]).assign(&ramp);
|
||||
taper
|
||||
.slice_mut(s![window_size - ramp_size..])
|
||||
.assign(&ramp.slice(s![..;-1]));
|
||||
let window = &taper * &taper.slice(s![.., NewAxis]);
|
||||
window
|
||||
}
|
||||
|
||||
fn goldstein_filter(
|
||||
data: ArrayView2<f32>,
|
||||
window_size: usize,
|
||||
overlap: usize,
|
||||
exponent: f32,
|
||||
) -> Array2<f32> {
|
||||
let window_stride = window_size - overlap;
|
||||
let window_non_overlap = window_size - 2 * overlap;
|
||||
let npx_x = data.nrows();
|
||||
let npx_y = data.ncols();
|
||||
|
||||
let nx = npx_x / window_stride;
|
||||
let ny = npx_y / window_stride;
|
||||
|
||||
let mut filtered_data = Array2::<f32>::zeros(data.raw_dim());
|
||||
let taper = triangular_taper(window_size, window_non_overlap);
|
||||
|
||||
for ix in 0..nx {
|
||||
let px_x_beg = cmp::min(ix * window_stride, npx_x);
|
||||
let px_x_end = cmp::min(px_x_beg + window_size, npx_x);
|
||||
for ny in 0..ny {
|
||||
let px_y_beg = cmp::min(ny * window_stride, npx_y);
|
||||
let px_y_end = cmp::min(px_y_beg + window_size, npx_y);
|
||||
|
||||
let window_data = data.slice(s![px_x_beg..px_x_end, px_y_beg..px_y_end]);
|
||||
let window_shape = [window_data.shape()[0], window_data.shape()[1]];
|
||||
let taper_window = taper.slice(s![..window_shape[0], ..window_shape[1]]);
|
||||
|
||||
let mut handler = R2cFftHandler::<f32>::new(window_shape[0]);
|
||||
let mut window_fft =
|
||||
Array2::<Complex<f32>>::default((window_shape[0] / 2 + 1, window_shape[1]));
|
||||
let mut window_filtered = Array2::<f32>::default(window_data.raw_dim());
|
||||
|
||||
ndfft_r2c(&window_data, &mut window_fft, &mut handler, 0);
|
||||
window_fft.mapv_inplace(|px| px.scale(px.norm().powf(exponent)));
|
||||
ndifft_r2c(&window_fft, &mut window_filtered, &mut handler, 0);
|
||||
|
||||
window_filtered *= &taper_window;
|
||||
|
||||
let mut filtered_slice =
|
||||
filtered_data.slice_mut(s![px_x_beg..px_x_end, px_y_beg..px_y_end]);
|
||||
filtered_slice += &window_filtered;
|
||||
}
|
||||
}
|
||||
filtered_data
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust.
|
||||
#[pymodule]
|
||||
fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
#[pyfn(m)]
|
||||
fn rust_func<'py>(
|
||||
#[pyo3(name = "triangular_taper")]
|
||||
fn triangular_taper_wrapper<'py>(
|
||||
py: Python<'py>,
|
||||
window_size: usize,
|
||||
plateau: usize,
|
||||
) -> PyResult<&'py PyArray2<f32>> {
|
||||
let taper = triangular_taper(window_size, plateau).into_pyarray(py);
|
||||
Ok(taper)
|
||||
}
|
||||
|
||||
#[pyfn(m)]
|
||||
#[pyo3(name = "goldstein_filter")]
|
||||
fn goldstein_filter_wrapper<'py>(
|
||||
py: Python<'py>,
|
||||
data: PyReadonlyArray2<'py, f64>,
|
||||
) -> PyResult<&'py PyArray1<f64>> {
|
||||
let x = vec![1., 2., 3.].into_pyarray(py);
|
||||
Ok(x)
|
||||
data: PyReadonlyArray2<'py, f32>,
|
||||
window_size: usize,
|
||||
overlap: usize,
|
||||
exponent: f32,
|
||||
) -> PyResult<&'py PyArray2<f32>> {
|
||||
let data_array = data.as_array();
|
||||
let result = goldstein_filter(data_array, window_size, overlap, exponent).into_pyarray(py);
|
||||
Ok(result)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
Loading…
Reference in New Issue