Browse Source

adding rust

master
Marius Isken 1 month ago
parent
commit
eece87ad77
  1. 209
      Cargo.lock
  2. 2
      Cargo.toml
  3. 57
      pyrocko_das/goldstein.py
  4. 14
      pyrocko_das/utils.py
  5. 1
      requirements-dev.txt
  6. 101
      src/lib.rs
  7. 30
      tests/test_goldstein.py

209
Cargo.lock

@ -26,6 +26,65 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "crossbeam-channel"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd"
dependencies = [
"cfg-if 1.0.0",
"crossbeam-utils",
"lazy_static",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
dependencies = [
"cfg-if 1.0.0",
"lazy_static",
]
[[package]]
name = "either"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "indoc"
version = "0.3.6"
@ -58,6 +117,12 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.104"
@ -82,6 +147,15 @@ dependencies = [
"rawpointer",
]
[[package]]
name = "memoffset"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9"
dependencies = [
"autocfg",
]
[[package]]
name = "ndarray"
version = "0.15.3"
@ -89,10 +163,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08e854964160a323e65baa19a0b1a027f76d590faba01f05c0cbc3187221a8c9"
dependencies = [
"matrixmultiply",
"num-complex",
"num-complex 0.4.0",
"num-integer",
"num-traits",
"rawpointer",
"rayon",
]
[[package]]
name = "ndrustfft"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13e5807dd25d129b57036ec928f372f0c14d3fc0415ce57afddf5d990c8ba002"
dependencies = [
"ndarray",
"num-traits",
"realfft",
"rustdct",
"rustfft 6.0.1",
]
[[package]]
name = "num-complex"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "747d632c0c558b87dbabbe6a82f3b4ae03720d0646ac5b7b4dae89394be5f2c5"
dependencies = [
"num-traits",
]
[[package]]
@ -123,6 +220,16 @@ dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]]
name = "numpy"
version = "0.14.1"
@ -132,7 +239,7 @@ dependencies = [
"cfg-if 0.1.10",
"libc",
"ndarray",
"num-complex",
"num-complex 0.4.0",
"num-traits",
"pyo3",
]
@ -187,6 +294,15 @@ dependencies = [
"proc-macro-hack",
]
[[package]]
name = "primal-check"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01419cee72c1a1ca944554e23d83e483e1bccf378753344e881de28b5487511d"
dependencies = [
"num-integer",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.19"
@ -255,8 +371,10 @@ name = "pyrocko-das"
version = "0.1.0"
dependencies = [
"ndarray",
"ndrustfft",
"numpy",
"pyo3",
"rayon",
]
[[package]]
@ -274,6 +392,40 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06aca804d41dbc8ba42dfd964f0d01334eceb64314b9ecf7c5fad5188a06d90"
dependencies = [
"autocfg",
"crossbeam-deque",
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"lazy_static",
"num_cpus",
]
[[package]]
name = "realfft"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7695c87f31dc3644760f23fb59a3fed47659703abf76cf2d111f03b9e712342"
dependencies = [
"rustfft 6.0.1",
]
[[package]]
name = "redox_syscall"
version = "0.2.10"
@ -283,6 +435,43 @@ dependencies = [
"bitflags",
]
[[package]]
name = "rustdct"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fadcb505b98aa64da1dadb1498b912e3642aae4606623cb3ae952cd8da33f80d"
dependencies = [
"rustfft 5.1.1",
]
[[package]]
name = "rustfft"
version = "5.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1869bb2a6ff77380d52ff4bc631f165637035a55855c76aa462c85474dadc42f"
dependencies = [
"num-complex 0.3.1",
"num-integer",
"num-traits",
"primal-check",
"strength_reduce",
"transpose",
]
[[package]]
name = "rustfft"
version = "6.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1d089e5c57521629a59f5f39bca7434849ff89bd6873b521afe389c1c602543"
dependencies = [
"num-complex 0.4.0",
"num-integer",
"num-traits",
"primal-check",
"strength_reduce",
"transpose",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
@ -295,6 +484,12 @@ version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
[[package]]
name = "strength_reduce"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3ff2f71c82567c565ba4b3009a9350a96a7269eaa4001ebedae926230bc2254"
[[package]]
name = "syn"
version = "1.0.80"
@ -306,6 +501,16 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "transpose"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95f9c900aa98b6ea43aee227fd680550cdec726526aab8ac801549eadb25e39f"
dependencies = [
"num-integer",
"strength_reduce",
]
[[package]]
name = "unicode-xid"
version = "0.2.2"

2
Cargo.toml

@ -12,6 +12,8 @@ crate-type = ["cdylib"]
[dependencies]
numpy = "0.14.1"
ndarray = "0.15.3"
ndrustfft = "0.2.1"
rayon = "1.2.0"
[dependencies.pyo3]
version = "0.14.5"

57
pyrocko_das/goldstein.py

@ -1,7 +1,6 @@
from __future__ import annotations
import os.path as op
import time
from functools import lru_cache
import matplotlib.pyplot as plt
@ -9,7 +8,7 @@ import numpy as num
from matplotlib.widgets import Slider
from scipy import signal
from .rust import rust_func as goldstein_rust
from .utils import timeit
data = num.load(op.join(op.dirname(__file__), "data", "data-DAS-gfz2020wswf.npy"))
@ -38,6 +37,7 @@ def data_coherency(data):
return coherency / (ntraces - 1)
@timeit
def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.3):
if num.log2(window_size) % 1.0 or window_size < 4:
raise ValueError("window_size has to be pow(2) and > 4.")
@ -54,6 +54,7 @@ def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.
raise ValueError("Padding does not match desired data shape")
filtered_data = num.zeros_like(data)
taper = triangular_taper(window_size, window_non_overlap)
for iwin_x in range(nwin_x):
px_x = iwin_x * window_stride
@ -64,21 +65,19 @@ def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.
window_data = data[slice_x, slice_y]
window_fft = num.fft.rfft2(window_data)
window_fft_abs = num.abs(window_fft) / window_fft.size
window_fft_abs = num.abs(window_fft)
# Optional
# window_fft_abs = signal.medfilt2d(window_fft_abs, kernel_size=3)
# window_coherency = data_coherency(window_data)
window_fft *= window_fft_abs ** exponent
window_filtered = num.fft.irfft2(window_fft)
taper = triangular_taper(window_size, window_non_overlap)
taper = taper[: window_filtered.shape[0], : window_filtered.shape[1]]
taper_this = taper[: window_filtered.shape[0], : window_filtered.shape[1]]
filtered_data[
px_x : px_x + window_filtered.shape[0],
px_y : px_y + window_filtered.shape[1],
] += (
window_filtered * taper
window_filtered * taper_this
)
return filtered_data
@ -92,25 +91,31 @@ def normalize_diff(data, filtered_data):
def plot_goldstein():
exponent = 0.5
from .rust import goldstein_filter as goldstein_filter_rust
goldstein_filter_rust = timeit(goldstein_filter_rust)
window_size = 64
overlap = 30
t = time.time()
data_filtered = goldstein_filter(
data, window_size=window_size, exponent=exponent, overlap=overlap
)
print(time.time() - t)
exponent = 0.5
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharex=True, sharey=True)
ax1.imshow(data, aspect=5)
image = ax2.imshow(data_filtered, aspect=5)
image_diff = ax3.imshow(normalize_diff(data, data_filtered), aspect=5)
data_filtered = goldstein_filter(data, window_size, overlap, exponent)
data_filtered_rust = goldstein_filter_rust(data, window_size, overlap, exponent)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharex=True, sharey=True)
ax1.imshow(data, aspect=5, cmap="viridis")
image_python = ax2.imshow(data_filtered, aspect=5, cmap="viridis")
image_rust = ax3.imshow(data_filtered_rust, aspect=5, cmap="viridis")
image_diff = ax4.imshow(
normalize_diff(data, data_filtered), aspect=5, cmap="viridis"
)
ax_slider = plt.axes((0.25, 0.1, 0.65, 0.03))
ax1.set_title("Data Input")
ax2.set_title("Data Filtered")
ax3.set_title("Data Residual")
ax2.set_title("Data Filtered (Python)")
ax3.set_title("Data Filtered (Rust)")
ax4.set_title("Data Residual")
exp_slider = Slider(
ax=ax_slider,
@ -127,8 +132,16 @@ def plot_goldstein():
data_filtered = goldstein_filter(
data, window_size=window_size, exponent=val, overlap=overlap
)
image.set_data(data_filtered)
image.set_norm(None)
data_filtered_rust = goldstein_filter_rust(
data, window_size=window_size, exponent=val, overlap=overlap
)
image_python.set_data(data_filtered)
image_python.set_norm(None)
image_rust.set_data(data_filtered_rust)
image_rust.set_norm(None)
image_diff.set_data(normalize_diff(data, data_filtered))
image_diff.set_norm(None)
@ -145,5 +158,5 @@ def plot_taper():
if __name__ == "__main__":
print(goldstein_rust())
# print(goldstein_rust())
plot_goldstein()

14
pyrocko_das/utils.py

@ -1,3 +1,6 @@
import time
from functools import wraps
import numpy as num
@ -20,3 +23,14 @@ def traces_to_numpy_and_meta(traces):
meta = tr.meta
return data, AttrDict(meta)
def timeit(func):
@wraps(func)
def wrapper(*args, **kwargs):
t = time.time()
ret = func(*args, **kwargs)
print(func.__qualname__, time.time() - t)
return ret
return wrapper

1
requirements-dev.txt

@ -1,3 +1,4 @@
pytest-benchmark
pytest
setuptools-rust
flake8

101
src/lib.rs

@ -1,22 +1,103 @@
use ndarray::Array;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2};
use ndarray::prelude::*;
use ndarray::{Array, Array2, ArrayView2};
use ndrustfft::{ndfft_r2c, ndifft_r2c, Complex, R2cFftHandler};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use pyo3::prelude::*;
fn triangular_taper(taper_size: usize, plateau: usize) -> Array<f32, _> {
assert!(plateau < taper_size, "Plateau cannot be larger than size.");
assert!(taper_size as f32 % 2.0 != 0., "Size has to be even");
use rayon::prelude::*;
use std::cmp;
fn triangular_taper(window_size: usize, plateau: usize) -> Array2<f32> {
assert!(plateau < window_size, "Plateau cannot be larger than size.");
assert!(window_size as f32 % 2.0 == 0., "Size has to be even");
let ramp_size = (window_size - plateau) / 2;
let ramp = Array::<f32, _>::linspace(0., 1., ramp_size);
let mut taper = Array::<f32, _>::ones(window_size);
taper.slice_mut(s![..ramp_size]).assign(&ramp);
taper
.slice_mut(s![window_size - ramp_size..])
.assign(&ramp.slice(s![..;-1]));
let window = &taper * &taper.slice(s![.., NewAxis]);
window
}
fn goldstein_filter(
data: ArrayView2<f32>,
window_size: usize,
overlap: usize,
exponent: f32,
) -> Array2<f32> {
let window_stride = window_size - overlap;
let window_non_overlap = window_size - 2 * overlap;
let npx_x = data.nrows();
let npx_y = data.ncols();
let nx = npx_x / window_stride;
let ny = npx_y / window_stride;
let mut filtered_data = Array2::<f32>::zeros(data.raw_dim());
let taper = triangular_taper(window_size, window_non_overlap);
for ix in 0..nx {
let px_x_beg = cmp::min(ix * window_stride, npx_x);
let px_x_end = cmp::min(px_x_beg + window_size, npx_x);
for ny in 0..ny {
let px_y_beg = cmp::min(ny * window_stride, npx_y);
let px_y_end = cmp::min(px_y_beg + window_size, npx_y);
let window_data = data.slice(s![px_x_beg..px_x_end, px_y_beg..px_y_end]);
let window_shape = [window_data.shape()[0], window_data.shape()[1]];
let taper_window = taper.slice(s![..window_shape[0], ..window_shape[1]]);
let mut handler = R2cFftHandler::<f32>::new(window_shape[0]);
let mut window_fft =
Array2::<Complex<f32>>::default((window_shape[0] / 2 + 1, window_shape[1]));
let mut window_filtered = Array2::<f32>::default(window_data.raw_dim());
ndfft_r2c(&window_data, &mut window_fft, &mut handler, 0);
window_fft.mapv_inplace(|px| px.scale(px.norm().powf(exponent)));
ndifft_r2c(&window_fft, &mut window_filtered, &mut handler, 0);
window_filtered *= &taper_window;
let mut filtered_slice =
filtered_data.slice_mut(s![px_x_beg..px_x_end, px_y_beg..px_y_end]);
filtered_slice += &window_filtered;
}
}
filtered_data
}
/// A Python module implemented in Rust.
#[pymodule]
fn rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfn(m)]
fn rust_func<'py>(
#[pyo3(name = "triangular_taper")]
fn triangular_taper_wrapper<'py>(
py: Python<'py>,
window_size: usize,
plateau: usize,
) -> PyResult<&'py PyArray2<f32>> {
let taper = triangular_taper(window_size, plateau).into_pyarray(py);
Ok(taper)
}
#[pyfn(m)]
#[pyo3(name = "goldstein_filter")]
fn goldstein_filter_wrapper<'py>(
py: Python<'py>,
data: PyReadonlyArray2<'py, f64>,
) -> PyResult<&'py PyArray1<f64>> {
let x = vec![1., 2., 3.].into_pyarray(py);
Ok(x)
data: PyReadonlyArray2<'py, f32>,
window_size: usize,
overlap: usize,
exponent: f32,
) -> PyResult<&'py PyArray2<f32>> {
let data_array = data.as_array();
let result = goldstein_filter(data_array, window_size, overlap, exponent).into_pyarray(py);
Ok(result)
}
Ok(())
}

30
tests/test_goldstein.py

@ -1,9 +1,12 @@
import numpy as num
import os.path as op
from pyrocko_das.goldstein import triangular_taper, goldstein_rust
import numpy as num
import pytest
from pyrocko_das import goldstein, rust
def get_data():
@pytest.fixture
def data():
import pyrocko_das
das_dir = pyrocko_das.__file__
@ -11,11 +14,24 @@ def get_data():
def test_taper():
window = triangular_taper(32, 4)
window = goldstein.triangular_taper(32, 4)
assert window[32 // 2, 32 // 2] == 1.0
assert window.shape == (32, 32)
window_rust = rust.triangular_taper(32, 4)
num.testing.assert_almost_equal(window, window_rust)
# def test_benchmark_goldstein(benchmark, data):
# benchmark(goldstein.goldstein_filter, data, 32, 14, 0.5)
# def test_benchmark_goldstein_rust(benchmark, data):
# benchmark(rust.goldstein_filter, data, 32, 14, 0.5)
def test_goldstein(data):
filtered_data = goldstein.goldstein_filter(data, 32, 14, 0.5)
filtered_data_rust = rust.goldstein_filter(data, 32, 14, 0.5)
def test_goldstein():
data = get_data().astype(num.float64)
print(goldstein_rust(data))
num.testing.assert_allclose(filtered_data_rust, filtered_data)

Loading…
Cancel
Save