You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

182 lines
5.1 KiB
Python

2 years ago
from __future__ import annotations
import os.path as op
from functools import lru_cache
import matplotlib.pyplot as plt
import numpy as num
from matplotlib.widgets import Slider
from scipy import signal
2 years ago
from matplotlib.colors import Normalize
from scipy.signal import butter, lfilter
2 years ago
2 years ago
from .utils import timeit
2 years ago
# Example DAS record stored in the ``data`` directory next to this module. It is
# loaded eagerly at import time, so importing this module fails if the file is missing.
_EXAMPLE_DATA_PATH = op.join(op.dirname(__file__), "data", "data-DAS-gfz2020wswf.npy")
data = num.load(_EXAMPLE_DATA_PATH)
def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Band-pass ``data`` along its first axis with a Butterworth IIR filter.

    The filter is designed with ``scipy.signal.butter`` in (b, a) form. It is
    applied once, forward only, with ``scipy.signal.lfilter`` along axis 0.

    Parameters
    ----------
    data : array_like
        Samples to filter; axis 0 is the axis the filter runs along.
    lowcut, highcut : float
        Lower and upper corner frequencies, in the same units as ``fs``.
    fs : float
        Sampling frequency.
    order : int
        Order argument passed to ``butter``.

    Returns
    -------
    numpy.ndarray
        Filtered samples, same shape as ``data``.
    """
    band = (lowcut, highcut)
    numerator, denominator = butter(order, band, btype="bandpass", fs=fs)
    return lfilter(numerator, denominator, data, axis=0)
@lru_cache
def triangular_taper(size: int, plateau: int):
    """Return a separable 2-D taper of shape ``(size, size)``.

    The 1-D profile has three parts:

    - It rises linearly from 0 to 1 over ``(size - plateau) / 2`` samples.
    - It stays at 1 over the central ``plateau`` samples.
    - It falls linearly back to 0 over the last ``(size - plateau) / 2`` samples.

    The 2-D taper is the outer product of this profile with itself.

    Results are memoised with ``lru_cache``, so every call with the same
    arguments returns the *same* array object. The array is therefore marked
    read-only, so a caller cannot silently corrupt the cached value.

    Parameters
    ----------
    size : int
        Edge length of the square taper; must be even.
    plateau : int
        Length of the flat, unit-valued centre. Must be even, non-negative
        and not larger than ``size``.

    Returns
    -------
    numpy.ndarray
        Read-only array of shape ``(size, size)``.

    Raises
    ------
    ValueError
        If ``plateau`` is negative or larger than ``size``, or if either
        argument is odd.
    """
    # A negative plateau passes the parity check (-2 % 2 == 0) but makes the
    # two ramps overlap, so reject it explicitly.
    if plateau < 0:
        raise ValueError("Plateau cannot be negative.")
    if plateau > size:
        raise ValueError("Plateau cannot be larger than size.")
    if size % 2 or plateau % 2:
        raise ValueError("Size and plateau have to be even.")
    ramp_size = int((size - plateau) / 2)
    # linspace includes both endpoints. For ramp_size >= 2, ramp[i] + ramp[::-1][i] == 1,
    # so a falling ramp overlapped with the next window's rising ramp sums to 1.
    ramp = num.linspace(0.0, 1.0, ramp_size)
    profile = num.ones(size)
    profile[:ramp_size] = ramp
    profile[size - ramp_size :] = ramp[::-1]
    taper = profile * profile[:, num.newaxis]
    # The cached object is shared by all callers; forbid in-place edits.
    taper.flags.writeable = False
    return taper
def data_coherency(data):
    """Return the mean magnitude-squared coherence between neighbouring traces.

    Each row of ``data`` is one trace. ``scipy.signal.coherence`` (default
    settings) is evaluated for every adjacent pair of rows ``(i, i + 1)``.
    The result is averaged over frequency and then over all pairs.

    Parameters
    ----------
    data : array_like
        Array with at least two dimensions; axis 0 indexes traces and the
        last axis holds the samples.

    Returns
    -------
    numpy.floating
        Mean adjacent-trace coherence.

    Raises
    ------
    ValueError
        If ``data`` is not at least 2-D or has fewer than two traces, in
        which case no adjacent pair exists.
    """
    data = num.asarray(data)
    if data.ndim < 2 or data.shape[0] < 2:
        raise ValueError(
            "data_coherency needs an array of at least 2 dimensions with at least two traces."
        )
    # One vectorised call computes the coherence of every (row i, row i + 1) pair.
    _, pair_coherence = signal.coherence(data[:-1], data[1:], axis=-1)
    # All pairs share the same frequency grid, so the grand mean equals the
    # mean of the per-pair frequency averages.
    return pair_coherence.mean()
def _n_windows(n_pixels: int, window_size: int, window_stride: int) -> int:
    """Return how many strided windows are needed to cover ``n_pixels`` samples."""
    # Smallest n >= 1 with (n - 1) * window_stride + window_size >= n_pixels,
    # computed with integer ceil-division.
    return max(1, -(-(n_pixels - window_size) // window_stride) + 1)


@timeit
def goldstein_filter(data, window_size: int = 32, overlap: int = 14, exponent=0.3):
    """Apply a Goldstein-type adaptive spectral filter to a 2-D array.

    The array is processed in square windows of ``window_size`` samples that
    start every ``window_size - overlap`` samples along both axes. For each
    window:

    1. Its 2-D real FFT is multiplied by its own magnitude raised to
       ``exponent``.
    2. The result is transformed back.
    3. It is weighted with
       ``triangular_taper(window_size, window_size - 2 * overlap)``.
    4. It is added into the output (overlap-add).

    Windows at the far edges are clipped by the array boundary. Enough
    windows are used that every sample is covered by at least one window.

    Parameters
    ----------
    data : array_like
        2-D input. Non-floating (e.g. integer) input is converted to float64
        so the floating-point accumulation is possible.
    window_size : int
        Window edge length; must be a power of two and >= 4.
    overlap : int
        Samples shared by neighbouring windows;
        ``0 <= overlap <= window_size / 2 - 1``.
    exponent : float
        Exponent applied to the spectral magnitude used as the weight. 0 makes
        every weight 1; larger values favour strong spectral components more.

    Returns
    -------
    numpy.ndarray
        Filtered array with the same shape as ``data``. It has the same dtype
        as ``data`` when that is already floating-point or complex.

    Raises
    ------
    ValueError
        If ``window_size`` or ``overlap`` is out of range.

    Notes
    -----
    The taper ramps start at 0. For ``overlap >= 1`` the first output row
    and column are therefore always 0, and samples near the leading edges are
    attenuated because no preceding window complements the rising ramp.

    NOTE(review): with ``overlap == 1`` each ramp is the single value
    ``linspace(0, 1, 1) == [0.]``. Every row/column shared by two windows is
    then zeroed too; confirm whether ``overlap == 1`` should be rejected.
    """
    # Check the lower bound first so log2 never sees a non-positive value.
    if window_size < 4 or num.log2(window_size) % 1.0:
        raise ValueError("window_size has to be pow(2) and >= 4.")
    if overlap < 0:
        raise ValueError("Overlap cannot be negative.")
    if overlap > window_size / 2 - 1:
        raise ValueError("Overlap is too large. Maximum overlap: window_size / 2 - 1.")

    data = num.asanyarray(data)
    if not num.issubdtype(data.dtype, num.inexact):
        # An integer output buffer cannot accumulate the float results below.
        data = data.astype(num.float64)
    if data.size == 0:
        return num.zeros_like(data)

    window_stride = window_size - overlap
    # Central part of each window that no neighbour overlaps; used as the taper plateau.
    window_non_overlap = window_size - 2 * overlap
    npx_x, npx_y = data.shape
    # n // stride windows can stop short of the array end and leave the tail at 0.
    nwin_x = _n_windows(npx_x, window_size, window_stride)
    nwin_y = _n_windows(npx_y, window_size, window_stride)

    filtered_data = num.zeros_like(data)
    taper = triangular_taper(window_size, window_non_overlap)

    for iwin_x in range(nwin_x):
        px_x = iwin_x * window_stride
        slice_x = slice(px_x, px_x + window_size)
        for iwin_y in range(nwin_y):
            px_y = iwin_y * window_stride
            slice_y = slice(px_y, px_y + window_size)
            window_data = data[slice_x, slice_y]

            window_fft = num.fft.rfft2(window_data)
            weights = num.abs(window_fft) ** exponent
            # Optional: smooth the weights first, e.g.
            # weights = signal.medfilt2d(weights, kernel_size=3)
            window_fft *= weights

            # Without ``s=``, irfft2 returns an even last-axis length and drops
            # one column of odd-width (clipped) edge windows.
            window_filtered = num.fft.irfft2(window_fft, s=window_data.shape)

            # Edge windows are clipped by the array boundary; crop the taper to match.
            taper_this = taper[: window_data.shape[0], : window_data.shape[1]]
            filtered_data[slice_x, slice_y] += window_filtered * taper_this
    return filtered_data
def normalize_diff(data, filtered_data):
    """Return the difference between two independently min-max scaled arrays.

    Each argument is scaled with its own fresh ``matplotlib.colors.Normalize``,
    which autoscales to that array's min and max (mapped to 0 and 1).
    ``filtered_data`` is then subtracted from ``data`` in that scaled space.
    This makes the residual insensitive to overall amplitude and offset.
    """
    scaled_input = Normalize()(data)
    scaled_output = Normalize()(filtered_data)
    return scaled_input - scaled_output
def plot_goldstein(data):
    """Show an interactive comparison of ``data`` and its Goldstein-filtered version.

    Three panels are drawn side by side:

    - the input;
    - the output of the Rust ``goldstein_filter`` (imported from ``.rust``);
    - the difference of the separately min-max normalised input and output.

    A slider at the bottom of the figure sets the filter exponent. Moving it
    re-runs the filter and redraws the panels. Blocks in ``plt.show()`` until
    the window is closed.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array to filter and display. The slider callback always filters
        this argument, not the module-level ``data``.
    """
    from .rust import goldstein_filter as goldstein_filter_rust

    filter_rust = timeit(goldstein_filter_rust)

    window_size = 32
    overlap = 14
    exponent = 0.5

    imshow_kwargs = dict(aspect=5, cmap="viridis", vmin=-1.0, vmax=1.0)

    fig, (ax_input, ax_rust, ax_diff) = plt.subplots(1, 3, sharex=True, sharey=True)
    ax_input.imshow(data, **imshow_kwargs)
    image_rust = ax_rust.imshow(data, **imshow_kwargs)
    image_diff = ax_diff.imshow(data, **imshow_kwargs)

    ax_input.set_title("Data Input")
    ax_rust.set_title("Data Filtered (Rust)")
    ax_diff.set_title("Data Residual")

    ax_slider = plt.axes((0.25, 0.05, 0.65, 0.03))
    exp_slider = Slider(
        ax=ax_slider,
        label="Exponent",
        valmin=0,
        valmax=1.0,
        valinit=exponent,
        orientation="horizontal",
    )

    def update(val):
        # Closes over the ``data`` argument. A ``global data`` here would
        # silently filter the module-level array instead.
        data_filtered = filter_rust(
            data, window_size=window_size, exponent=val, overlap=overlap
        )

        image_rust.set_data(data_filtered)
        # Replace the fixed vmin/vmax with a default (autoscaling) norm.
        image_rust.set_norm(None)

        image_diff.set_data(normalize_diff(data, data_filtered))
        image_diff.set_norm(None)

        fig.canvas.draw_idle()

    update(exponent)
    exp_slider.on_changed(update)
    plt.show()
def plot_taper(size: int = 32, plateau: int = 4):
    """Display ``triangular_taper(size, plateau)`` as an image.

    The defaults give the taper ``goldstein_filter`` builds with its default
    arguments (window_size=32, overlap=14, so plateau 32 - 2 * 14 = 4).
    Blocks in ``plt.show()`` until the window is closed.
    """
    window = triangular_taper(size, plateau)
    plt.imshow(window)
    plt.show()
if __name__ == "__main__":
    import argparse
    import pathlib

    # CLI: one positional path to a .npy array to filter and display.
    cli = argparse.ArgumentParser()
    cli.add_argument("npy_file", type=pathlib.Path)
    npy_path = cli.parse_args().npy_file
    # Rebinds the module-level ``data`` to the user's array, cast to float32.
    data = num.load(npy_path).astype(num.float32)
    plot_goldstein(data)