Browse Source

goldstein: 2d param window

master
Marius Isken 4 weeks ago
parent
commit
0168c8494b
  1. 129
      lightguide/goldstein.py
  2. 246
      src/lib.rs
  3. 33
      tests/test_goldstein.py

129
lightguide/goldstein.py

@ -3,11 +3,8 @@ from __future__ import annotations
import os.path as op
from functools import lru_cache
import matplotlib.pyplot as plt
import numpy as num
from matplotlib.widgets import Slider
from scipy import signal
from matplotlib.colors import Normalize
from scipy.signal import butter, lfilter
from .utils import timeit
@ -36,16 +33,13 @@ def triangular_taper(size: int, plateau: int):
return window * window[:, num.newaxis]
def data_coherency(data):
ntraces = data.shape[0]
coherency = 0.0
for itr in range(ntraces - 1):
coherency += signal.coherence(data[itr], data[itr + 1])[1].mean()
return coherency / (ntraces - 1)
def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.3):
def goldstein_filter(
data,
window_size: int = 32,
overlap: int = 14,
exponent: float = 0.3,
normalize_power: bool = False,
):
if num.log2(window_size) % 1.0 or window_size < 4:
raise ValueError("window_size has to be pow(2) and > 4.")
if overlap > window_size / 2 - 1:
@ -72,14 +66,17 @@ def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.
window_data = data[slice_x, slice_y]
window_fft = num.fft.rfft2(window_data)
weights = num.abs(window_fft) ** exponent
# Optional
# weights = signal.medfilt2d(weights, kernel_size=3)
# weights /= weights.sum()
# window_coherency = data_coherency(window_data)
power_spec = num.abs(window_fft)
# power_spec = signal.medfilt2d(power_spec, kernel_size=3)
if normalize_power:
power_spec /= power_spec.max()
weights = power_spec ** exponent
window_fft *= weights
# window_fft /= weights.sum()
window_flt = num.fft.irfft2(window_fft)
# window_flt /= weights.max()
taper_this = taper[: window_flt.shape[0], : window_flt.shape[1]]
window_flt *= taper_this
filtered_data[
@ -88,99 +85,3 @@ def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.
] += window_flt
return filtered_data
def normalize_diff(data, filtered_data):
return Normalize()(data) - Normalize()(filtered_data)
def plot_goldstein(data):
def r(data):
v = num.std(data) * 2
return -v, v
from .rust import goldstein_filter as goldstein_filter_rust
window_size = 64
overlap = 30
exponent = 0.5
adaptive_weights = False
imshow_kwargs = dict(aspect=5, cmap="viridis", vmin=-1.0, vmax=1.0)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharex=True, sharey=True)
ax1.imshow(data, **imshow_kwargs)
image_python = ax2.imshow(data, **imshow_kwargs)
image_rust = ax3.imshow(data, **imshow_kwargs)
image_diff = ax4.imshow(data, **imshow_kwargs)
ax_slider = plt.axes((0.25, 0.05, 0.65, 0.03))
ax1.set_title("Data Input")
ax2.set_title("Data Filtered (Python)")
ax3.set_title("Data Filtered (Rust)")
ax4.set_title("Data Residual")
exp_slider = Slider(
ax=ax_slider,
label="Exponent",
valmin=0,
valmax=1.0,
valinit=exponent,
orientation="horizontal",
)
def update(val):
global data
# data = data_bp
data_filtered_py = goldstein_filter(
data, window_size=window_size, exponent=val, overlap=overlap
)
data_filtered = goldstein_filter_rust(
data,
window_size=window_size,
exponent=val,
overlap=overlap,
adaptive_weights=adaptive_weights,
)
image_python.set_data(data_filtered_py)
image_python.set_norm(Normalize(*r(data_filtered_py)))
image_python.set_norm(None)
image_rust.set_data(data_filtered)
image_rust.set_norm(Normalize(*r(data_filtered)))
image_rust.set_norm(None)
norm_diff = normalize_diff(data, data_filtered)
image_diff.set_data(norm_diff)
image_diff.set_norm(Normalize(*r(norm_diff)))
image_diff.set_norm(None)
fig.canvas.draw_idle()
update(exponent)
exp_slider.on_changed(update)
plt.show()
def plot_taper():
window = triangular_taper(32, 4)
plt.imshow(window)
plt.show()
if __name__ == "__main__":
import argparse
import pathlib
parser = argparse.ArgumentParser()
parser.add_argument("npy_file", type=pathlib.Path)
args = parser.parse_args()
data = num.load(args.npy_file).astype(num.float32)
plot_goldstein(data)

246
src/lib.rs

@ -40,23 +40,201 @@ fn triangular_taper(window_size: usize, plateau: usize) -> Array2<f32> {
.slice(s![1..-1])
.to_owned();
let mut taper_both = Array1::<f32>::ones(window_size);
taper_both.slice_mut(s![..ramp_size]).assign(&ramp);
taper_both
let mut taper = Array1::<f32>::ones(window_size);
taper.slice_mut(s![..ramp_size]).assign(&ramp);
taper
.slice_mut(s![window_size - ramp_size..])
.assign(&ramp.slice(s![..;-1]));
let taper = &taper_both * &taper_both.slice(s![.., NewAxis]);
let taper = &taper * &taper.slice(s![.., NewAxis]);
taper
}
fn triangular_taper_2d(window_size: (usize, usize), plateau: (usize, usize)) -> Array2<f32> {
assert!(
plateau.0 < window_size.0 && plateau.1 < window_size.1,
"Plateau cannot be larger than size."
);
assert!(
window_size.0 as f32 % 2.0 == 0. && window_size.1 as f32 % 2.0 == 0.,
"Window size has to be even"
);
let ramp_size = (
(window_size.0 - plateau.0) / 2,
(window_size.1 - plateau.1) / 2,
);
// Crop 0.0 and 1.0
let ramp = (
Array1::<f32>::linspace(0., 1., ramp_size.0 + 2)
.slice(s![1..-1])
.to_owned(),
Array1::<f32>::linspace(0., 1., ramp_size.1 + 2)
.slice(s![1..-1])
.to_owned(),
);
let mut taper = (
Array1::<f32>::ones(window_size.0).to_owned(),
Array1::<f32>::ones(window_size.1).to_owned(),
);
taper.0.slice_mut(s![..ramp_size.0]).assign(&ramp.0);
taper
.0
.slice_mut(s![window_size.0 - ramp_size.0..])
.assign(&ramp.0.slice(s![..;-1]));
taper.1.slice_mut(s![..ramp_size.1]).assign(&ramp.1);
taper
.1
.slice_mut(s![window_size.1 - ramp_size.1..])
.assign(&ramp.1.slice(s![..;-1]));
let window = &taper.1 * &taper.0.slice(s![.., NewAxis]);
window
}
fn goldstein_filter_rect(
data: ArrayView2<f32>,
window_size: (usize, usize),
overlap: (usize, usize),
exponent: f32,
normalize_power: bool,
) -> Array2<f32> {
assert!(
window_size.0 > 4 && window_size.1 > 4,
"Bad window_size: {:?}. window_size has to be base 2 and > 4.",
window_size
);
assert!(
overlap.0 < (window_size.0 / 2),
"overlap {} is too large. Maximum overlap: {}",
overlap.0,
window_size.0 / 2 - 1
);
assert!(
overlap.1 < (window_size.1 / 2),
"overlap {} is too large. Maximum overlap: {}",
overlap.1,
window_size.1 / 2 - 1
);
let window_stride = (window_size.0 - overlap.0, window_size.1 - overlap.1);
let window_non_overlap = (window_size.0 - 2 * overlap.0, window_size.1 - 2 * overlap.1);
let window_px = window_size.0 * window_size.1;
let mut data_padded = Array2::<f32>::zeros((
data.nrows() + 2 * window_stride.0,
data.ncols() + 2 * window_stride.1,
));
data_padded
.slice_mut(s![
window_stride.0..data.nrows() + window_stride.0,
window_stride.1..data.ncols() + window_stride.1
])
.assign(&data);
let npx_x = data_padded.nrows();
let npx_y = data_padded.ncols();
let nx = npx_x / window_stride.0;
let ny = npx_y / window_stride.1;
let window_shape = (window_size.0, window_size.1);
let taper = triangular_taper_2d(window_size, window_non_overlap);
let mut frames = HashMap::with_capacity(nx * ny);
for ix in 0..nx {
let px_x_beg = cmp::min(ix * window_stride.0, npx_x);
let px_x_end = cmp::min(px_x_beg + window_size.0, npx_x);
for iy in 0..ny {
let px_y_beg = cmp::min(iy * window_stride.1, npx_y);
let px_y_end = cmp::min(px_y_beg + window_size.1, npx_y);
let tile = Tile {
px_x: (px_x_beg, px_x_end),
px_y: (px_y_beg, px_y_end),
ix: ix,
iy: iy,
};
let window_data = data_padded
.slice(s![px_x_beg..px_x_end, px_y_beg..px_y_end])
.to_owned();
frames.insert(tile, window_data);
}
}
let fft2_r2c = R2CPlan32::aligned(&[window_size.0, window_size.1], Flag::MEASURE).unwrap();
let fft2_c2r = C2RPlan32::aligned(&[window_size.0, window_size.1], Flag::MEASURE).unwrap();
frames.par_iter_mut().for_each(|(_tile, window_data)| {
let frame_shape = [window_data.shape()[0], window_data.shape()[1]];
let fft_size = window_size.0 * (window_size.1 / 2 + 1);
let mut window_data_fft = AlignedVec::new(fft_size);
let mut power_spec = Array1::<f32>::default(fft_size);
if frame_shape != [window_size.0, window_size.1] {
let mut window_padded = Array2::<f32>::zeros(window_shape);
window_padded
.slice_mut(s![..frame_shape[0], ..frame_shape[1]])
.assign(window_data);
*window_data = window_padded;
}
let mut window_data_slice = window_data.as_slice_mut().unwrap();
fft2_r2c
.r2c(&mut window_data_slice, &mut window_data_fft)
.unwrap();
let mut power_max: f32 = 0.;
for (ipx, px) in window_data_fft.iter().enumerate() {
power_spec[ipx] = px.norm();
power_max = power_max.max(power_spec[ipx]);
}
if normalize_power {
power_spec /= power_max;
}
// Filter the spectrum
for (ipx, px) in window_data_fft.iter_mut().enumerate() {
*px *= power_spec[ipx].powf(exponent);
}
fft2_c2r
.c2r(&mut window_data_fft, &mut window_data_slice)
.unwrap();
// Normalize fft
*window_data /= window_px as f32;
*window_data *= &taper;
if frame_shape != [window_size.0, window_size.1] {
*window_data = window_data
.slice(s![..frame_shape[0], ..frame_shape[1]])
.to_owned();
}
});
let mut filtered_data = Array2::<f32>::zeros(data_padded.raw_dim());
for (tile, window_data) in frames.iter() {
let mut filtered_slice =
filtered_data.slice_mut(s![tile.px_x.0..tile.px_x.1, tile.px_y.0..tile.px_y.1]);
filtered_slice += window_data;
}
filtered_data
.slice(s![
window_stride.0..data.nrows() + window_stride.0,
window_stride.1..data.ncols() + window_stride.1
])
.to_owned()
}
fn goldstein_filter(
data: ArrayView2<f32>,
window_size: usize,
overlap: usize,
exponent: f32,
adaptive_weights: bool,
normalize_power: bool,
) -> Array2<f32> {
assert!(
((window_size as f64).log2() % 1.0 == 0.) && window_size > 4,
@ -117,8 +295,10 @@ fn goldstein_filter(
let fft2_c2r = C2RPlan32::aligned(&[window_size, window_size], Flag::MEASURE).unwrap();
frames.par_iter_mut().for_each(|(_tile, window_data)| {
let frame_shape = [window_data.shape()[0], window_data.shape()[1]];
let fft_size = (window_size / 2 + 1) * window_size;
let mut window_data_fft = AlignedVec::new(fft_size);
let mut power_spec = Array1::<f32>::default(fft_size);
let mut window_data_fft = AlignedVec::new((window_size / 2 + 1) * window_size);
if frame_shape != [window_size, window_size] {
let mut window_padded = Array2::<f32>::zeros(window_shape);
window_padded
@ -132,27 +312,26 @@ fn goldstein_filter(
.r2c(&mut window_data_slice, &mut window_data_fft)
.unwrap();
// Scale the spectrum
let mut weight: f32;
let mut window_power = 0.;
for px in window_data_fft.iter_mut() {
weight = px.norm().powf(exponent);
window_power += px.norm().powf(2.);
*px *= weight;
let mut power_max: f32 = 0.;
for (ipx, px) in window_data_fft.iter().enumerate() {
power_spec[ipx] = px.norm();
power_max = power_max.max(power_spec[ipx]);
}
if normalize_power {
power_spec /= power_max;
}
if adaptive_weights {
for px in window_data_fft.iter_mut() {
*px /= window_power / window_px as f32;
}
// Filter the spectrum and normalize
for (ipx, px) in window_data_fft.iter_mut().enumerate() {
*px *= power_spec[ipx].powf(exponent);
*px /= window_px as f32;
}
fft2_c2r
.c2r(&mut window_data_fft, &mut window_data_slice)
.unwrap();
// Normalize fft
*window_data /= window_px as f32;
*window_data *= &taper;
if frame_shape != [window_size, window_size] {
@ -183,10 +362,10 @@ fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyo3(name = "triangular_taper")]
fn triangular_taper_wrapper<'py>(
py: Python<'py>,
window_size: usize,
plateau: usize,
window_size: (usize, usize),
plateau: (usize, usize),
) -> &'py PyArray2<f32> {
let taper = triangular_taper(window_size, plateau).into_pyarray(py);
let taper = triangular_taper_2d(window_size, plateau).into_pyarray(py);
taper
}
@ -198,12 +377,29 @@ fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
window_size: usize,
overlap: usize,
exponent: f32,
adaptive_weights: bool,
normalize_power: bool,
) -> &'py PyArray2<f32> {
let data_array = data.as_array();
let result = goldstein_filter(data_array, window_size, overlap, exponent, adaptive_weights)
let result = goldstein_filter(data_array, window_size, overlap, exponent, normalize_power)
.into_pyarray(py);
result
}
#[pyfn(m)]
#[pyo3(name = "goldstein_filter_rect")]
fn goldstein_filter_rect_wrapper<'py>(
py: Python<'py>,
data: PyReadonlyArray2<'py, f32>,
window_size: (usize, usize),
overlap: (usize, usize),
exponent: f32,
normalize_power: bool,
) -> &'py PyArray2<f32> {
let data_array = data.as_array();
let result =
goldstein_filter_rect(data_array, window_size, overlap, exponent, normalize_power)
.into_pyarray(py);
result
}
Ok(())
}

33
tests/test_goldstein.py

@ -24,15 +24,14 @@ def test_taper():
assert window[32 // 2, 32 // 2] == 1.0
assert window.shape == (32, 32)
taper_rust = rust.triangular_taper(32, 4)
window_rust = taper_rust["full"]
num.testing.assert_almost_equal(window, window_rust)
taper_rust = rust.triangular_taper((32, 32), (4, 4))
num.testing.assert_almost_equal(window, taper_rust)
def test_plot_taper():
import matplotlib.pyplot as plt
taper_rust = rust.triangular_taper(32, 4)
taper_rust = rust.triangular_taper((32, 64), (4, 10))
fig = plt.figure()
ax = fig.gca()
@ -51,13 +50,25 @@ def test_benchmark_goldstein_rust(benchmark, data_big):
benchmark(rust.goldstein_filter, data_big, 32, 14, 0.5, False)
def test_goldstein(data):
import matplotlib.pyplot as plt
def test_goldstein_rust(data):
filtered_data_rust = rust.goldstein_filter(
data, 32, 14, exponent=0.0, normalize_power=False
)
filtered_data = goldstein.goldstein_filter(data, 32, 14, 0.5)
filtered_data_rust = rust.goldstein_filter(data, 32, 14, 0.5, False)
filtered_data_rust_rect = rust.goldstein_filter_rect(
data, (32, 32), (14, 14), exponent=0.0, normalize_power=False
)
num.testing.assert_almost_equal(filtered_data_rust_rect, filtered_data_rust)
num.testing.assert_allclose(data, filtered_data_rust, rtol=1.0)
plt.imshow(filtered_data_rust)
plt.show()
filtered_data_rust_rect = rust.goldstein_filter_rect(
data, (32, 16), (14, 7), exponent=0.0, normalize_power=False
)
filtered_data_rust_rect = rust.goldstein_filter_rect(
data, (32, 128), (14, 56), exponent=0.0, normalize_power=False
)
# num.testing.assert_allclose(filtered_data_rust, filtered_data)
filtered_data_rust_rect = rust.goldstein_filter_rect(
data, (32, 200), (14, 80), exponent=0.0, normalize_power=False
)

Loading…
Cancel
Save