You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
182 lines
5.1 KiB
Python
182 lines
5.1 KiB
Python
2 years ago
|
from __future__ import annotations
|
||
|
|
||
|
import os.path as op
|
||
|
from functools import lru_cache
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as num
|
||
|
from matplotlib.widgets import Slider
|
||
|
from scipy import signal
|
||
2 years ago
|
from matplotlib.colors import Normalize
|
||
|
from scipy.signal import butter, lfilter
|
||
2 years ago
|
|
||
2 years ago
|
from .utils import timeit
|
||
2 years ago
|
|
||
|
# Sample DAS recording stored in the ``data`` directory next to this module.
# It is read eagerly when the module is imported.
data = num.load(
    op.join(
        op.dirname(__file__),
        "data",
        "data-DAS-gfz2020wswf.npy",
    )
)
|
||
2 years ago
|
|
||
|
|
||
|
def butter_bandpass_filter(data, lowcut, highcut, fs, order=4, axis=0):
    """Apply a causal Butterworth band-pass filter to *data*.

    The filter is designed in second-order-section (SOS) form and applied
    with :func:`scipy.signal.sosfilt`. SciPy recommends SOS for filtering
    because the transfer-function (``b, a``) form loses precision, or becomes
    unstable, for higher orders and low/narrow cut-off bands.

    Parameters
    ----------
    data : array_like
        Input signal(s).
    lowcut, highcut : float
        Lower and upper corner frequencies, in the same units as *fs*.
    fs : float
        Sampling frequency of *data* along *axis*.
    order : int, optional
        Order of the Butterworth prototype; the band-pass design has twice
        this order. Default 4.
    axis : int, optional
        Axis of *data* along which to filter. Default 0.

    Returns
    -------
    numpy.ndarray
        Filtered data with the same shape as *data*.
    """
    sos = butter(order, (lowcut, highcut), btype="bandpass", fs=fs, output="sos")
    return signal.sosfilt(sos, data, axis=axis)
|
||
2 years ago
|
|
||
|
|
||
|
@lru_cache
def triangular_taper(size: int, plateau: int):
    """Return a 2-D taper with linear ramps around a flat plateau.

    The 1-D profile rises linearly from 0 to 1 over
    ``ramp_size = (size - plateau) // 2`` samples, stays at 1 for the
    *plateau* samples and falls back to 0 symmetrically. The result is the
    outer product of that profile with itself, shape ``(size, size)``.

    For ``ramp_size >= 2``, the falling ramp of one window plus the rising
    ramp of a window shifted by ``size - ramp_size`` samples sum to exactly 1.
    This gives seamless overlap-add.

    Results are cached, so repeated calls with the same arguments return the
    same array object. That array is marked read-only so that an in-place
    modification by a caller cannot corrupt the cache.

    Raises
    ------
    ValueError
        If *plateau* is larger than *size* or negative, or if either is odd.
    """
    if plateau > size:
        raise ValueError("Plateau cannot be larger than size.")
    if plateau < 0:
        raise ValueError("Plateau cannot be negative.")
    if size % 2 or plateau % 2:
        raise ValueError("Size and plateau have to be even.")

    ramp_size = (size - plateau) // 2
    ramp = num.linspace(0.0, 1.0, ramp_size)

    profile = num.ones(size)
    profile[:ramp_size] = ramp
    profile[size - ramp_size :] = ramp[::-1]

    window = profile * profile[:, num.newaxis]
    # The cache hands out this exact object; forbid in-place writes.
    window.flags.writeable = False
    return window
|
||
|
|
||
|
|
||
|
def data_coherency(data):
    """Return the mean coherence between neighbouring traces of *data*.

    Entries along the first axis are treated as traces. For each adjacent
    pair ``(data[i], data[i + 1])``, the magnitude-squared coherence is
    estimated with :func:`scipy.signal.coherence` (default parameters, along
    the last axis) and averaged. The result is the average over all pairs.

    All pairs are evaluated in one vectorised call rather than a Python
    loop. Every pair yields the same number of coherence values, so the mean
    over the whole coherence array equals the mean of the per-pair means.

    Raises
    ------
    ValueError
        If *data* is not at least 2-D or has fewer than two traces, since no
        neighbouring pair exists.
    """
    data = num.asarray(data)
    if data.ndim < 2 or data.shape[0] < 2:
        raise ValueError("At least two traces are required to compute coherency.")

    _, coherence = signal.coherence(data[:-1], data[1:], axis=-1)
    return coherence.mean()
|
||
|
|
||
|
|
||
2 years ago
|
@timeit
def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent: float = 0.3):
    """Apply a Goldstein-style adaptive spectral filter to 2-D *data*.

    *data* is covered by square windows of ``window_size`` samples, placed
    every ``window_size - overlap`` samples along both axes. For each window:

    1. It is transformed with a real 2-D FFT.
    2. Its spectrum is multiplied by its own magnitude raised to *exponent*,
       so that for ``exponent > 0`` strong spectral components gain relative
       to weak ones.
    3. It is transformed back, weighted by ``triangular_taper`` and added
       into the output.

    Parameters
    ----------
    data : array_like
        2-D input array.
    window_size : int
        Window edge length; must be a power of two and at least 4.
    overlap : int
        Number of samples shared by neighbouring windows; at most
        ``window_size / 2 - 1``.
    exponent : float
        Exponent applied to the spectral magnitude to form the weights.

    Returns
    -------
    numpy.ndarray
        Filtered array with the same shape as *data*. Floating-point input
        keeps its dtype; any other input yields ``float64``.

    Raises
    ------
    ValueError
        If *window_size* is not a power of two >= 4, or *overlap* is out of
        range (a negative overlap is rejected by ``triangular_taper``).

    Notes
    -----
    Windows at the far edges may be truncated by the array bounds. Samples
    that no window reaches, at the far edges when the shape is not covered
    by whole strides plus the overlap, stay zero in the output. Edge samples
    lying under only one taper ramp are attenuated. When ``overlap > 0``, the
    first row and column receive zero weight.
    """
    # Cheap range check first so log2 never sees zero or negative sizes.
    if window_size < 4 or num.log2(window_size) % 1.0:
        raise ValueError("window_size has to be a power of 2 and >= 4.")
    if overlap > window_size / 2 - 1:
        raise ValueError("Overlap is too large. Maximum overlap: window_size / 2 - 1.")

    data = num.asarray(data)
    window_stride = window_size - overlap
    window_non_overlap = window_size - 2 * overlap

    npx_x, npx_y = data.shape
    nwin_x = npx_x // window_stride
    nwin_y = npx_y // window_stride

    # Accumulate in floating point. Adding the float-valued windows into an
    # integer array would raise a casting error under NumPy's "same_kind" rule.
    if num.issubdtype(data.dtype, num.floating):
        out_dtype = data.dtype
    else:
        out_dtype = num.float64
    filtered_data = num.zeros(data.shape, dtype=out_dtype)

    taper = triangular_taper(window_size, window_non_overlap)

    for iwin_x in range(nwin_x):
        px_x = iwin_x * window_stride
        slice_x = slice(px_x, px_x + window_size)

        for iwin_y in range(nwin_y):
            px_y = iwin_y * window_stride
            slice_y = slice(px_y, px_y + window_size)

            window_data = data[slice_x, slice_y]
            window_fft = num.fft.rfft2(window_data)

            weights = num.abs(window_fft) ** exponent
            # Optionally smooth the weights first,
            # e.g. signal.medfilt2d(weights, kernel_size=3).
            window_fft *= weights

            # Pass the window's own shape: without ``s`` irfft2 assumes an
            # even last axis and returns one column too few for edge windows
            # truncated to an odd width.
            window_filtered = num.fft.irfft2(window_fft, s=window_data.shape)

            rows, cols = window_filtered.shape
            filtered_data[px_x : px_x + rows, px_y : px_y + cols] += (
                window_filtered * taper[:rows, :cols]
            )

    return filtered_data
|
||
|
|
||
|
|
||
|
def normalize_diff(data, filtered_data):
    """Return ``data - filtered_data`` after min-max scaling each to [0, 1].

    Each array gets its own fresh auto-scaling ``Normalize`` instance, so
    offset and amplitude differences between the two are removed before
    subtracting. The result is the (masked) array produced by ``Normalize``.
    """
    scaled_input = Normalize()(data)
    scaled_output = Normalize()(filtered_data)
    return scaled_input - scaled_output
|
||
2 years ago
|
|
||
|
|
||
2 years ago
|
def plot_goldstein(data):
    """Interactively compare *data* with its Goldstein-filtered version.

    Opens a figure with three panels:

    - the input;
    - the output of the Rust ``goldstein_filter``;
    - the residual between the min-max-normalised input and output.

    A slider below the panels re-runs the filter with a new exponent. Blocks
    in ``plt.show()`` until the window is closed.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array passed unchanged to the Rust filter. The command-line entry
        point casts it to ``float32`` first; presumably the extension expects
        that dtype (TODO confirm against ``.rust``).
    """
    from .rust import goldstein_filter as goldstein_filter_rust

    goldstein_filter_rust = timeit(goldstein_filter_rust)

    window_size = 32
    overlap = 14
    exponent = 0.5

    imshow_kwargs = dict(aspect=5, cmap="viridis", vmin=-1.0, vmax=1.0)

    fig, (ax_input, ax_rust, ax_diff) = plt.subplots(1, 3, sharex=True, sharey=True)
    ax_input.imshow(data, **imshow_kwargs)
    image_rust = ax_rust.imshow(data, **imshow_kwargs)
    image_diff = ax_diff.imshow(data, **imshow_kwargs)

    ax_input.set_title("Data Input")
    ax_rust.set_title("Data Filtered (Rust)")
    ax_diff.set_title("Data Residual")

    ax_slider = fig.add_axes((0.25, 0.05, 0.65, 0.03))
    exp_slider = Slider(
        ax=ax_slider,
        label="Exponent",
        valmin=0,
        valmax=1.0,
        valinit=exponent,
        orientation="horizontal",
    )

    def update(val):
        """Re-filter with exponent *val* and refresh the output panels."""
        # ``data`` is this call's argument, not the module-level array, so
        # the filtered panels always correspond to the displayed input.
        data_filtered = goldstein_filter_rust(
            data, window_size=window_size, exponent=val, overlap=overlap
        )

        # set_norm(None) installs a fresh auto-scaling Normalize, so the
        # colour limits follow the newly set data.
        image_rust.set_data(data_filtered)
        image_rust.set_norm(None)

        image_diff.set_data(normalize_diff(data, data_filtered))
        image_diff.set_norm(None)

        fig.canvas.draw_idle()

    update(exponent)
    exp_slider.on_changed(update)
    plt.show()
|
||
|
|
||
|
|
||
|
def plot_taper():
    """Show the 32x32 taper with a 4-sample plateau as an image."""
    plt.imshow(triangular_taper(32, 4))
    plt.show()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    import argparse
    import pathlib

    # Command-line entry point: load an ``.npy`` array as float32 and open
    # the interactive Goldstein comparison plot for it.
    cli = argparse.ArgumentParser()
    cli.add_argument("npy_file", type=pathlib.Path)
    npy_path = cli.parse_args().npy_file

    # Rebinds the module-level ``data`` that was loaded at import time.
    data = num.load(npy_path).astype(num.float32)
    plot_goldstein(data)
|