Marius Isken 1 month ago
parent
commit
e3d2ee4809
  1. 8
      Cargo.lock
  2. 1
      lightguide/__init__.py
  3. 48
      lightguide/goldstein.py
  4. 20
      src/lib.rs
  5. 15
      tests/test_goldstein.py

8
Cargo.lock

@ -522,9 +522,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
[[package]]
name = "proc-macro2"
version = "1.0.30"
version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edc3358ebc67bc8b7fa0c007f945b0b18226f78437d61bec735a9eb96b61ee70"
checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43"
dependencies = [
"unicode-xid",
]
@ -659,9 +659,9 @@ checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
[[package]]
name = "syn"
version = "1.0.80"
version = "1.0.81"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d010a1623fbd906d51d650a9916aaefc05ffa0e4053ff7fe601167f3e715d194"
checksum = "f2afee18b8beb5a596ecb4a2dce128c719b4ba399d34126b9e4396e3f9860966"
dependencies = [
"proc-macro2",
"quote",

1
lightguide/__init__.py

@ -0,0 +1 @@
from . import rust

48
lightguide/goldstein.py

@ -45,7 +45,6 @@ def data_coherency(data):
return coherency / (ntraces - 1)
@timeit
def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.3):
if num.log2(window_size) % 1.0 or window_size < 4:
raise ValueError("window_size has to be pow(2) and > 4.")
@ -76,17 +75,17 @@ def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.
weights = num.abs(window_fft) ** exponent
# Optional
# weights = signal.medfilt2d(weights, kernel_size=3)
# weights /= weights.sum()
# window_coherency = data_coherency(window_data)
window_fft *= weights
window_filtered = num.fft.irfft2(window_fft)
taper_this = taper[: window_filtered.shape[0], : window_filtered.shape[1]]
# window_fft /= weights.sum()
window_flt = num.fft.irfft2(window_fft)
taper_this = taper[: window_flt.shape[0], : window_flt.shape[1]]
window_flt *= taper_this
filtered_data[
px_x : px_x + window_filtered.shape[0],
px_y : px_y + window_filtered.shape[1],
] += (
window_filtered * taper_this
)
px_x : px_x + window_flt.shape[0],
px_y : px_y + window_flt.shape[1],
] += window_flt
return filtered_data
@ -97,29 +96,28 @@ def normalize_diff(data, filtered_data):
def plot_goldstein(data):
def r(data):
v = num.std(data)
v = num.std(data) * 2
return -v, v
from .rust import goldstein_filter as goldstein_filter_rust
goldstein_filter_rust = timeit(goldstein_filter_rust)
window_size = 32
overlap = 14
window_size = 64
overlap = 30
exponent = 0.5
adaptive_weights = False
imshow_kwargs = dict(aspect=5, cmap="viridis", vmin=-1.0, vmax=1.0)
fig, (ax1, ax3, ax4) = plt.subplots(1, 3, sharex=True, sharey=True)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharex=True, sharey=True)
ax1.imshow(data, **imshow_kwargs)
# image_python = ax2.imshow(data, **imshow_kwargs)
image_python = ax2.imshow(data, **imshow_kwargs)
image_rust = ax3.imshow(data, **imshow_kwargs)
image_diff = ax4.imshow(data, **imshow_kwargs)
ax_slider = plt.axes((0.25, 0.05, 0.65, 0.03))
ax1.set_title("Data Input")
# ax2.set_title("Data Filtered (Python)")
ax2.set_title("Data Filtered (Python)")
ax3.set_title("Data Filtered (Rust)")
ax4.set_title("Data Residual")
@ -136,22 +134,26 @@ def plot_goldstein(data):
global data
# data = data_bp
# data_filtered = goldstein_filter(
# data, window_size=window_size, exponent=val, overlap=overlap
# )
data_filtered_py = goldstein_filter(
data, window_size=window_size, exponent=val, overlap=overlap
)
data_filtered = goldstein_filter_rust(
data, window_size=window_size, exponent=val, overlap=overlap
data, window_size=window_size, exponent=val, overlap=overlap,
adaptive_weights=adaptive_weights
)
# image_python.set_data(data_filtered)
# image_python.set_norm(Normalize(*r(data_filtered)))
image_python.set_data(data_filtered_py)
image_python.set_norm(Normalize(*r(data_filtered_py)))
image_python.set_norm(None)
image_rust.set_data(data_filtered)
image_rust.set_norm(Normalize(*r(data_filtered)))
image_rust.set_norm(None)
norm_diff = normalize_diff(data, data_filtered)
image_diff.set_data(norm_diff)
image_diff.set_norm(Normalize(*r(norm_diff)))
image_diff.set_norm(None)
fig.canvas.draw_idle()

20
src/lib.rs

@ -35,6 +35,7 @@ fn goldstein_filter(
window_size: usize,
overlap: usize,
exponent: f32,
adaptive_weights: bool,
) -> Array2<f32> {
assert!(
((window_size as f64).log2() % 1.0 == 0.) && window_size > 4,
@ -48,6 +49,7 @@ fn goldstein_filter(
let window_stride = window_size - overlap;
let window_non_overlap = window_size - 2 * overlap;
let window_px = window_size * window_size;
let npx_x = data.nrows();
let npx_y = data.ncols();
@ -91,8 +93,19 @@ fn goldstein_filter(
.unwrap();
// Scale the spectrum
let mut weight: f32;
let mut weight_sum = 0.;
for px in window_data_fft.iter_mut() {
*px *= px.norm().powf(exponent);
weight = px.norm().powf(exponent);
weight_sum += px.norm();
*px *= weight;
}
if adaptive_weights {
for px in window_data_fft.iter_mut() {
*px *= weight_sum.powf(2);
}
}
fft2_c2r
@ -100,6 +113,7 @@ fn goldstein_filter(
.unwrap();
//Apply the taper
*window_data /= window_px as f32;
*window_data *= &taper;
if frame_shape != [window_size, window_size] {
@ -141,9 +155,11 @@ fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
window_size: usize,
overlap: usize,
exponent: f32,
adaptive_weights: bool,
) -> PyResult<&'py PyArray2<f32>> {
let data_array = data.as_array();
let result = goldstein_filter(data_array, window_size, overlap, exponent).into_pyarray(py);
let result = goldstein_filter(data_array, window_size, overlap, exponent, adaptive_weights)
.into_pyarray(py);
Ok(result)
}
Ok(())

15
tests/test_goldstein.py

@ -15,7 +15,7 @@ def data():
@pytest.fixture
def data_big():
n = 8192
n = 2048
return num.random.uniform(size=(n, n)).astype(num.float32)
@ -28,16 +28,19 @@ def test_taper():
num.testing.assert_almost_equal(window, window_rust)
# def test_benchmark_goldstein(benchmark, data):
# benchmark(goldstein.goldstein_filter, data, 32, 14, 0.5)
@pytest.mark.skip
def test_benchmark_goldstein(benchmark, data_big):
benchmark(goldstein.goldstein_filter, data_big, 32, 14, 0.5)
@pytest.mark.skip
def test_benchmark_goldstein_rust(benchmark, data_big):
benchmark(rust.goldstein_filter, data_big, 32, 14, 0.5)
benchmark(rust.goldstein_filter, data_big, 32, 14, 0.5, False)
def asdtest_goldstein(data):
@pytest.mark.skip
def test_goldstein(data):
filtered_data = goldstein.goldstein_filter(data, 32, 14, 0.5)
filtered_data_rust = rust.goldstein_filter(data, 32, 14, 0.5)
filtered_data_rust = rust.goldstein_filter(data, 32, 14, 0.5, False)
num.testing.assert_allclose(filtered_data_rust, filtered_data)

Loading…
Cancel
Save