7 changed files with 373 additions and 41 deletions
@ -1,22 +1,103 @@
|
||||
use ndarray::Array; |
||||
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2}; |
||||
use ndarray::prelude::*; |
||||
|
||||
use ndarray::{Array, Array2, ArrayView2}; |
||||
use ndrustfft::{ndfft_r2c, ndifft_r2c, Complex, R2cFftHandler}; |
||||
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2}; |
||||
use pyo3::prelude::*; |
||||
|
||||
fn triangular_taper(taper_size: usize, plateau: usize) -> Array<f32, _> { |
||||
assert!(plateau < taper_size, "Plateau cannot be larger than size."); |
||||
assert!(taper_size as f32 % 2.0 != 0., "Size has to be even"); |
||||
use rayon::prelude::*; |
||||
|
||||
use std::cmp; |
||||
|
||||
fn triangular_taper(window_size: usize, plateau: usize) -> Array2<f32> { |
||||
assert!(plateau < window_size, "Plateau cannot be larger than size."); |
||||
assert!(window_size as f32 % 2.0 == 0., "Size has to be even"); |
||||
|
||||
let ramp_size = (window_size - plateau) / 2; |
||||
let ramp = Array::<f32, _>::linspace(0., 1., ramp_size); |
||||
let mut taper = Array::<f32, _>::ones(window_size); |
||||
|
||||
taper.slice_mut(s![..ramp_size]).assign(&ramp); |
||||
taper |
||||
.slice_mut(s![window_size - ramp_size..]) |
||||
.assign(&ramp.slice(s![..;-1])); |
||||
let window = &taper * &taper.slice(s![.., NewAxis]); |
||||
window |
||||
} |
||||
|
||||
fn goldstein_filter( |
||||
data: ArrayView2<f32>, |
||||
window_size: usize, |
||||
overlap: usize, |
||||
exponent: f32, |
||||
) -> Array2<f32> { |
||||
let window_stride = window_size - overlap; |
||||
let window_non_overlap = window_size - 2 * overlap; |
||||
let npx_x = data.nrows(); |
||||
let npx_y = data.ncols(); |
||||
|
||||
let nx = npx_x / window_stride; |
||||
let ny = npx_y / window_stride; |
||||
|
||||
let mut filtered_data = Array2::<f32>::zeros(data.raw_dim()); |
||||
let taper = triangular_taper(window_size, window_non_overlap); |
||||
|
||||
for ix in 0..nx { |
||||
let px_x_beg = cmp::min(ix * window_stride, npx_x); |
||||
let px_x_end = cmp::min(px_x_beg + window_size, npx_x); |
||||
for ny in 0..ny { |
||||
let px_y_beg = cmp::min(ny * window_stride, npx_y); |
||||
let px_y_end = cmp::min(px_y_beg + window_size, npx_y); |
||||
|
||||
let window_data = data.slice(s![px_x_beg..px_x_end, px_y_beg..px_y_end]); |
||||
let window_shape = [window_data.shape()[0], window_data.shape()[1]]; |
||||
let taper_window = taper.slice(s![..window_shape[0], ..window_shape[1]]); |
||||
|
||||
let mut handler = R2cFftHandler::<f32>::new(window_shape[0]); |
||||
let mut window_fft = |
||||
Array2::<Complex<f32>>::default((window_shape[0] / 2 + 1, window_shape[1])); |
||||
let mut window_filtered = Array2::<f32>::default(window_data.raw_dim()); |
||||
|
||||
ndfft_r2c(&window_data, &mut window_fft, &mut handler, 0); |
||||
window_fft.mapv_inplace(|px| px.scale(px.norm().powf(exponent))); |
||||
ndifft_r2c(&window_fft, &mut window_filtered, &mut handler, 0); |
||||
|
||||
window_filtered *= &taper_window; |
||||
|
||||
let mut filtered_slice = |
||||
filtered_data.slice_mut(s![px_x_beg..px_x_end, px_y_beg..px_y_end]); |
||||
filtered_slice += &window_filtered; |
||||
} |
||||
} |
||||
filtered_data |
||||
} |
||||
|
||||
/// A Python module implemented in Rust.
|
||||
#[pymodule] |
||||
fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> { |
||||
#[pyfn(m)] |
||||
fn rust_func<'py>( |
||||
#[pyo3(name = "triangular_taper")] |
||||
fn triangular_taper_wrapper<'py>( |
||||
py: Python<'py>, |
||||
window_size: usize, |
||||
plateau: usize, |
||||
) -> PyResult<&'py PyArray2<f32>> { |
||||
let taper = triangular_taper(window_size, plateau).into_pyarray(py); |
||||
Ok(taper) |
||||
} |
||||
|
||||
#[pyfn(m)] |
||||
#[pyo3(name = "goldstein_filter")] |
||||
fn goldstein_filter_wrapper<'py>( |
||||
py: Python<'py>, |
||||
data: PyReadonlyArray2<'py, f64>, |
||||
) -> PyResult<&'py PyArray1<f64>> { |
||||
let x = vec![1., 2., 3.].into_pyarray(py); |
||||
Ok(x) |
||||
data: PyReadonlyArray2<'py, f32>, |
||||
window_size: usize, |
||||
overlap: usize, |
||||
exponent: f32, |
||||
) -> PyResult<&'py PyArray2<f32>> { |
||||
let data_array = data.as_array(); |
||||
let result = goldstein_filter(data_array, window_size, overlap, exponent).into_pyarray(py); |
||||
Ok(result) |
||||
} |
||||
Ok(()) |
||||
} |
||||
|
Loading…
Reference in new issue