refactor, better threading

develop
Marius Isken 3 years ago
parent debf6dc7eb
commit f31f2479e7

@@ -18,8 +18,11 @@ idas_convert dump_config
### Example Config
```yaml
--- !idas.iDASConvertConfig
# Loading TDMS in parallel
batch_size: 1
# Parallel TDMS loading and processing
nthreads_loading: 1
nthreads_processing: 8
queue_size: 32
processing_batch_size: 8
# Threads used for downsampling the data
nthreads: 8

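The new options map onto a three-stage producer/consumer pipeline: `nthreads_loading` loader threads feed a bounded priority queue of `queue_size` slots, a single processing thread (driving a pool of `nthreads_processing` workers) consumes `processing_batch_size` files at a time, and one saving thread writes the output. A minimal sketch of the queue wiring, with illustrative names (`pipeline_sketch`, `load_worker`) that are not part of the module:

```python
import queue
import threading

def pipeline_sketch(filenames, nthreads_loading=1, queue_size=32):
    # Bounded queues apply back-pressure between the stages.
    fn_queue = queue.PriorityQueue(maxsize=queue_size)
    traces_queue = queue.PriorityQueue(maxsize=queue_size)
    stop = threading.Event()

    def load_worker():
        while not stop.is_set():
            try:
                ifn, fn = fn_queue.get(timeout=1.)
            except queue.Empty:
                continue
            # The real worker calls io.load(fn, format='tdms_idas') here.
            traces_queue.put((ifn, 'traces of %s' % fn))
            fn_queue.task_done()

    threads = [threading.Thread(target=load_worker)
               for _ in range(nthreads_loading)]
    for thread in threads:
        thread.start()

    # Sequence numbers let the consumer restore file order downstream.
    for ifn, fn in enumerate(filenames):
        fn_queue.put((ifn, fn))

    fn_queue.join()  # block until every file has been picked up
    stop.set()       # then let the worker threads exit
    return traces_queue
```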
@@ -63,7 +63,7 @@ def main():
converter = config.get_converter()
try:
converter.start(checkpt_file=checkpt_file)
converter.start(checkpt_file)
except Exception as e:
logger.exception(e)
raise e
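The checkpoint file written by the saving thread holds the tmax of the last flushed batch as a plain string, so a wrapper script could seed a resumed run from it. A sketch under that assumption (`read_checkpoint` is illustrative, not part of the CLI):

```python
import os

def read_checkpoint(checkpt_file):
    # SaveMSeedThread writes str(tmax) after each saved batch,
    # so the file holds a single epoch timestamp.
    if checkpt_file is None or not os.path.exists(checkpt_file):
        return None
    with open(checkpt_file) as f:
        return float(f.read().strip())
```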

@@ -2,8 +2,9 @@ import os
import re
import logging
import threading
import queue
from time import time
from time import time, sleep
from glob import iglob
from datetime import timedelta, datetime
from itertools import repeat, chain
@@ -19,7 +20,7 @@ from pyrocko.guts import Object, String, Int, List, Timestamp
from .plugin import PluginConfig, PLUGINS_AVAILABLE
from .meta import Path
from .utils import Signal
from .utils import Signal, sizeof_fmt
guts_prefix = 'idas'
@@ -33,75 +34,6 @@ op = os.path
day = 3600. * 24
@dataclass
class Stats(object):
io_load_t: float = 0.
io_load_t_total: float = 0.
io_save_t: float = 0.
io_save_t_total: float = 0.
io_load_bytes: int = 0
io_load_bytes_total: int = 0
tprocessing: float = 0.
tprocessing_total: float = 0.
nfiles_total: int = 0
nfiles_processed: int = 0
time_processing: float = 0.
time_start: float = time()
processed_tmax: float = 0.
def new_io_load(self, t, bytes):
self.io_load_t = t
self.io_load_bytes = bytes
self.io_load_bytes_total += bytes
self.io_load_t_total += t
def new_io_tsave(self, t):
self.io_save_t = t
self.io_save_t_total += t
def new_tprocessing(self, t):
self.tprocessing = t
self.tprocessing_total += t
def finished_batch(self, nfiles):
self.nfiles_processed += nfiles
@property
def nfiles_remaining(self):
return self.nfiles_total - self.nfiles_processed
@property
def time_remaining(self):
proc_time = time() - self.time_start
s = self.nfiles_remaining*(proc_time/(self.nfiles_processed or 1))
return timedelta(seconds=s)
@property
def time_remaining_str(self):
return str(self.time_remaining)[:-7]
@property
def duration(self):
return timedelta(seconds=time() - self.time_start)
@property
def io_load_speed(self):
return (self.io_load_bytes / 1e6) / \
(self.io_load_t or 1.)
@property
def io_load_speed_avg(self):
return (self.io_load_bytes_total / 1e6) / \
(self.io_load_t_total or 1.)
@property
def processed_tmax_str(self):
return tts(self.processed_tmax)
def split(tr, time):
try:
return (tr.chop(tr.tmin, time, inplace=False),
@@ -135,6 +67,45 @@ def detect_files(path):
return files
class LoadTDMSThread(threading.Thread):
def __init__(self, fn_queue, traces_queue):
super().__init__()
self.fn_queue = fn_queue
self.traces_queue = traces_queue
self.bytes_loaded = 0
self.time_loading = 0.
self.stop = threading.Event()
@property
def bytes_input_rate(self):
return self.bytes_loaded / (self.time_loading or 1.)
def run(self):
logger.info('Starting loading thread %s', self.name)
while not self.stop.is_set():
try:
ifn, fn = self.fn_queue.get(timeout=1.)
except queue.Empty:
continue
logger.debug(
'Loading %s (thread: %s)',
op.basename(fn), threading.get_ident())
t_start = time()
traces = io.load(fn, format='tdms_idas')
fsize = op.getsize(fn)
self.time_loading += time() - t_start
self.bytes_loaded += fsize
self.traces_queue.put((ifn, traces, fn, fsize))
self.fn_queue.task_done()
def process_data(args):
trace, deltat, tmin, tmax = args
@@ -154,7 +125,201 @@ def process_data(args):
return trace
def load_idas(fn):
def get_traces_end(traces, overlap=1.):
trs_chopped = []
for tr in traces:
try:
trs_chopped.append(
tr.chop(tr.tmax - overlap, tr.tmax, inplace=False))
except trace.NoData:
return []
return trs_chopped
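get_traces_end keeps the trailing `overlap` seconds of every trace so the next batch can be degapped against it; if any trace is shorter than the overlap, nothing is carried over. A toy round trip with synthetic data, assuming pyrocko is installed (the station name `D001` is made up):

```python
import numpy as num
from pyrocko import trace

deltat = 1. / 200.  # 200 Hz, the rate after downsampling
a = trace.Trace(station='D001', tmin=0., deltat=deltat,
                ydata=num.zeros(400, dtype=num.int32))
b = trace.Trace(station='D001', tmin=a.tmax + deltat, deltat=deltat,
                ydata=num.zeros(400, dtype=num.int32))

# Keep the last second of `a`, as get_traces_end() does ...
tail = a.chop(a.tmax - 1., a.tmax, inplace=False)

# ... and degapper() stitches it seamlessly onto the next batch.
merged = trace.degapper(sorted([tail, b], key=lambda tr: tr.full_id))
assert len(merged) == 1
```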
class ProcessingThread(threading.Thread):
def __init__(self, processing_queue, out_queue,
downsample_to=200.,
new_network_code='ID', new_channel_code='HSF',
channel_selection=None,
tmin=None, tmax=None, nthreads=12, batch_size=12):
super().__init__()
self.in_traces = processing_queue
self.out_traces = out_queue
self.downsample_to = downsample_to
self.channel_selection = channel_selection
self.new_network_code = new_network_code
self.new_channel_code = new_channel_code
self.tmin = tmin
self.tmax = tmax
self.deltat = 1./self.downsample_to
self.batch_size = batch_size
self.executor = ThreadPoolExecutor(max_workers=nthreads)
self.ifn = 0
self.nfiles = 0
self.bytes_in = 0
self.time_processing = 0.
self.stop = threading.Event()
@property
def processing_file_rate(self):
return (self.ifn + 1) / (self.time_processing or 1.)
@property
def processing_rate(self):
return self.bytes_in / (self.time_processing or 1.)
def get_new_traces(self):
new_trs = []
fns = []
while len(new_trs) < self.batch_size:
try:
this_ifn, trs, fn, size = self.in_traces.get(timeout=10.)
if this_ifn != self.ifn:
self.in_traces.put((this_ifn, trs, fn, size))
self.in_traces.task_done()
sleep(.1)
continue
new_trs.append(trs)
fns.append(fn)
self.ifn += 1
self.bytes_in += size
except queue.Empty:
break
return new_trs, fns
def run(self):
logger.info('Starting processing thread')
trs_overlap = []
batch_tmin = self.tmin
while not self.stop.is_set():
new_trs, new_fns = self.get_new_traces()
nnew_trs = len(new_trs)
new_trs = list(chain(*new_trs))
if self.channel_selection:
new_trs = [tr for tr in new_trs
if self.channel_selection.match(tr.station)]
if self.stop.is_set() and not new_trs:
break
elif not new_trs:
continue
logger.debug('Start processing %d trace groups', nnew_trs)
t_start = time()
trs_degapped = trace.degapper(
sorted(new_trs + trs_overlap, key=lambda tr: tr.full_id))
if trs_overlap and batch_tmin \
and len(trs_degapped) != len(trs_overlap):
logger.warning('Gap detected at %s', tts(batch_tmin))
trs_degapped = trace.degapper(
sorted(new_trs, key=lambda tr: tr.full_id))
trs_overlap = get_traces_end(trs_degapped)
trs_ds = list(self.executor.map(
process_data,
zip(trs_degapped,
repeat(1./self.downsample_to),
repeat(batch_tmin), repeat(self.tmax))))
trs_ds = list(filter(lambda tr: tr is not None, trs_ds))
if not trs_ds:
# Still acknowledge the consumed items, else queue.join() hangs
for _ in range(nnew_trs):
self.in_traces.task_done()
continue
for tr in trs_ds:
tr.set_network(self.new_network_code)
tr.set_channel(self.new_channel_code)
batch_tmax = max(tr.tmax for tr in trs_ds)
batch_tmin = min(tr.tmin for tr in trs_ds)
# Split traces at day break
dt_min = datetime.fromtimestamp(batch_tmin)
dt_max = datetime.fromtimestamp(batch_tmax)
if dt_min.date() != dt_max.date():
dt_split = datetime.combine(dt_max.date(), datetime.min.time())
tsplit = dt_split.timestamp()
trs_ds = list(chain(*(split(tr, tsplit) for tr in trs_ds)))
batch_tmax += self.deltat
batch_tmin = batch_tmax
self.out_traces.put((batch_tmax, trs_ds, new_fns))
for _ in range(nnew_trs):
self.in_traces.task_done()
self.time_per_batch = time() - t_start
self.time_processing += self.time_per_batch
logger.info(
'Processed %d groups in %.2f s', nnew_trs, self.time_per_batch)
self.executor.shutdown()
logger.info('Shutting down processing thread')
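Because the loader threads finish files out of order, get_new_traces accepts only the item whose sequence number matches the next expected index and pushes everything else back for a later pass. The same pattern in isolation (an illustrative helper, not module code):

```python
import queue
from time import sleep

def take_next(pq, expected, timeout=10.):
    # Pop (seq, payload) items until the expected sequence number
    # arrives; out-of-order items are re-queued so nothing is lost.
    while True:
        try:
            seq, payload = pq.get(timeout=timeout)
        except queue.Empty:
            return None
        if seq != expected:
            pq.put((seq, payload))
            pq.task_done()  # balance the get() of the re-queued item
            sleep(.1)       # give the loaders time to catch up
            continue
        return payload
```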
class SaveMSeedThread(threading.Thread):
def __init__(self, in_queue, outpath,
record_length=4096, checkpt_file=None):
super().__init__()
self.queue = in_queue
self.outpath = outpath
self.record_length = record_length
self.checkpt_file = checkpt_file
self.tmax = 0.
self.processed_files = queue.Queue()
def set_checkpt_file(self, path):
self.checkpt_file = path
def get_tmax(self):
return self.tmax
def run(self):
logger.info('Starting MiniSeed saving thread')
while True:
tmax, traces, fns = self.queue.get()
if traces is False:
logger.debug('Shutting down saving thread')
self.queue.task_done()
return
t_start = time()
io.save(
traces, self.outpath,
format='mseed',
record_length=self.record_length,
append=True)
if self.checkpt_file is not None:
with open(self.checkpt_file, 'w') as f:
f.write(str(tmax))
self.tmax = tmax
self.processed_files.put(fns)
self.queue.task_done()
logger.debug(
'Saved %d traces from queue in %.1f s',
len(traces), time() - t_start)
def load_idas_thread(fn):
logger.debug(
'Loading %s (thread: %s)', op.basename(fn), threading.get_ident())
return io.load(fn, format='tdms_idas')
@@ -166,11 +331,12 @@ class iDASConvert(object):
self, paths, outpath,
downsample_to=200., record_length=4096,
new_network_code='ID', new_channel_code='HSF',
channel_selection=None,
tmin=None, tmax=None,
nthreads=8, batch_size=1, plugins=[]):
channel_selection=None, tmin=None, tmax=None,
nthreads_loading=8, nthreads_processing=24,
queue_size=32, processing_batch_size=8, plugins=[]):
if tmin is not None and tmax is not None:
assert tmin < tmax
logger.info('Detecting files...')
files = []
for path in paths:
@@ -199,193 +365,194 @@ class iDASConvert(object):
logger.info('Sorting %d files', len(files))
self.files = sorted(files, key=lambda f: op.basename(f))
self.bytes_total = sum([op.getsize(f) for f in self.files])
logger.info('Got %s of data', sizeof_fmt(self.bytes_total))
self.files_all = self.files.copy()
self.nfiles = len(self.files_all)
self.stats = Stats(nfiles_total=self.nfiles)
self.nfiles_processed = 0
self.t_start = 0.
self.processing_batch_size = processing_batch_size
self.outpath = outpath
self.channel_selection = None if not channel_selection \
channel_selection = None if not channel_selection \
else re.compile(channel_selection)
self.new_network_code = new_network_code
self.new_channel_code = new_channel_code
self.downsample_to = downsample_to
self.record_length = record_length
self.tmin = tmin
self.tmax = tmax
self.processed_tmax = 0.
self.nthreads = nthreads
self.batch_size = batch_size
self.before_batch_load = Signal(self)
self.finished_batch = Signal(self)
self.before_file_read = Signal(self)
self.new_traces_converted = Signal(self)
self.finished_batch = Signal(self)
self.finished = Signal(self)
self.load_fn_queue = queue.PriorityQueue(maxsize=queue_size)
self.processing_queue = queue.PriorityQueue(maxsize=queue_size)
self.save_queue = queue.PriorityQueue()
self.plugins = plugins
for plugin in self.plugins:
logger.info('Activating plugin %s', plugin.__class__.__name__)
plugin.set_parent(self)
self._trs_prev = []
self._tmax_prev = None
# Starting worker threads
self.load_threads = []
for ithread in range(nthreads_loading):
thread = LoadTDMSThread(self.load_fn_queue, self.processing_queue)
thread.name = 'LoadTDMS-%02d' % ithread
thread.start()
self.load_threads.append(thread)
self.processing_thread = ProcessingThread(
self.processing_queue, self.save_queue,
downsample_to, new_network_code, new_channel_code,
channel_selection, tmin, tmax,
nthreads_processing, processing_batch_size)
self.processing_thread.start()
self.save_thread = SaveMSeedThread(
self.save_queue, self.outpath, record_length)
self.save_thread.start()
@property
def nfiles_left(self):
return len(self.files)
def start(self, checkpt_file=None):
logger.info('Starting conversion of %d files', self.nfiles)
stats = self.stats
t_start = time()
files = self.files
trs_overlap = None
tmax_prev = None
while self.files:
load_files = []
self.before_batch_load.dispatch(self.files)
while files and len(load_files) < self.batch_size:
fn = self.files.pop(0)
self.before_file_read.dispatch(fn)
load_files.append(fn)
with open(fn, 'rb') as f:
if not detect_tdms(f.read(512)):
logger.warning('Not a tdms file %s', fn)
continue
batch_tmin = tmax_prev
if self.tmin is not None:
batch_tmin = max(self.tmin, batch_tmin or -1.)
traces, batch_tmin, batch_tmax, trs_overlap = self.convert_files(
load_files,
tmin=batch_tmin,
tmax=self.tmax,
overlap=trs_overlap)
@property
def load_queue_size(self):
return self.load_fn_queue.qsize()
self.finished_batch.dispatch(load_files)
stats.finished_batch(len(load_files))
tmax_prev = batch_tmax + 1./self.downsample_to
stats.processed_tmax = tmax_prev
if checkpt_file:
with open(checkpt_file, 'w') as f:
f.write(str(stats.processed_tmax))
logger.info(
'Processed {s.nfiles_processed}/{s.nfiles_total} files,'
' head at {s.processed_tmax_str}.'
' DS: {s.tprocessing:.2f},'
' IO: {s.io_load_t:.2f}/{s.io_save_t:.2f}'
' [load: {s.io_load_speed:.2f} MB/s].'
' Estimated time remaining: {s.time_remaining_str}'.format(
s=stats))
self.finished.dispatch()
logger.info('Finished. Processed %d files in %.2f s',
stats.nfiles_processed, time() - t_start)
def convert_files(self, files, tmin=None, tmax=None, overlap=False):
nfiles = len(files)
stats = self.stats
max_workers = nfiles if nfiles < self.batch_size else self.batch_size
@property
def process_queue_size(self):
return self.processing_queue.qsize()
t_start = time()
nbytes = sum(op.getsize(fn) for fn in files)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
trs_all = list(chain(*executor.map(load_idas, files)))
@property
def save_queue_size(self):
return self.save_queue.qsize()
if self.channel_selection:
trs_all = [tr for tr in trs_all
if self.channel_selection.match(tr.station)]
if not trs_all:
raise TypeError('Did not load any traces!')
@property
def nfiles_remaining(self):
return len(self.files)
stats.new_io_load(time() - t_start, nbytes)
@property
def bytes_loaded(self):
return sum(thr.bytes_loaded for thr in self.load_threads)
trs = sorted(trs_all + (overlap or []), key=lambda tr: tr.full_id)
trs_degapped = trace.degapper(trs)
if (overlap and tmin) and len(trs_degapped) != len(overlap):
logger.warning('Gap detected at %s', tts(tmin))
trs_degapped = trace.degapper(
sorted(trs_all, key=lambda tr: tr.full_id))
@property
def bytes_remaining(self):
return self.bytes_total - self.bytes_loaded
trs_overlap = self.get_traces_end(trs_degapped)
@property
def bytes_processing_rate(self):
return self.bytes_loaded / self.duration
t_start = time()
with ThreadPoolExecutor(max_workers=self.nthreads) as executor:
trs_ds = list(executor.map(
process_data,
zip(trs_degapped,
repeat(1./self.downsample_to),
repeat(tmin), repeat(tmax))))
@property
def bytes_input_rate(self):
return sum(t.bytes_input_rate for t in self.load_threads)
trs_ds = list(filter(lambda tr: tr is not None, trs_ds))
@property
def processing_rate(self):
return self.processing_thread.processing_rate
if not trs_ds:
return [], None, None, trs_overlap
@property
def duration(self):
return time() - self.t_start
for tr in trs_degapped:
tr.set_network(self.new_network_code)
tr.set_channel(self.new_channel_code)
@property
def time_remaining(self):
if not self.bytes_processing_rate:
return timedelta(seconds=0.)
s = self.bytes_remaining / self.bytes_processing_rate
return timedelta(seconds=s)
batch_tmax = max(tr.tmax for tr in trs_ds)
batch_tmin = min(tr.tmin for tr in trs_ds)
@property
def time_remaining_str(self):
return str(self.time_remaining)[:-7]
stats.new_tprocessing(time() - t_start)
@property
def time_head(self):
return self.save_thread.get_tmax()
self.new_traces_converted.dispatch(trs_ds)
def start(self, checkpt_file=None):
logger.info('Starting conversion of %d files', self.nfiles)
ifn = 0
self.save_thread.set_checkpt_file(checkpt_file)
t_start = time()
self.before_batch_load.dispatch(self.files)
self.t_start = time()
# Split traces at day break
dt_min = datetime.fromtimestamp(batch_tmin)
dt_max = datetime.fromtimestamp(batch_tmax)
while self.files:
if ifn % self.processing_batch_size == 0:
self.before_batch_load.dispatch(self.files)
fn = self.files.pop(0)
self.before_file_read.dispatch(fn)
with open(fn, 'rb') as f:
if not detect_tdms(f.read(512)):
logger.warning('Not a tdms file %s', fn)
continue
self.load_fn_queue.put((ifn, fn))
self.check_processed()
ifn += 1
logger.debug('Joining load queue')
self.load_fn_queue.join()
for thread in self.load_threads:
thread.stop.set()
logger.debug('Joined load queue')
logger.debug('Joining processing queue')
self.processing_queue.join()
self.processing_thread.stop.set()
logger.debug('Joined processing queue')
logger.debug('Joining save trace queue')
# Sentinel: wall-clock time() sorts after any processed tmax,
# ensuring it is the last element in the priority queue.
self.save_queue.put((time(), False, None))
self.save_queue.join()
self.check_processed()
self.finished.dispatch()
if dt_min.date() != dt_max.date():
dt_split = datetime.combine(dt_max.date(), datetime.min.time())
tsplit = dt_split.timestamp()
trs_ds = list(chain(*(split(tr, tsplit) for tr in trs_ds)))
def check_processed(self):
proc_fns_queue = self.save_thread.processed_files
if proc_fns_queue.empty():
return
io.save(
trs_ds, self.outpath,
format='mseed',
record_length=self.record_length,
append=True)
stats.new_io_tsave(time() - t_start)
finished_fns = []
while not proc_fns_queue.empty():
finished_fns.extend(proc_fns_queue.get_nowait())
return trs_ds, batch_tmin, batch_tmax, trs_overlap
self.nfiles_processed += len(finished_fns)
self.finished_batch.dispatch(finished_fns)
def get_traces_end(self, traces, overlap=1.):
trs_chopped = []
for tr in traces:
try:
trs_chopped.append(
tr.chop(tr.tmax - overlap, tr.tmax, inplace=False))
except trace.NoData:
return []
logger.info(self.get_status())
return trs_chopped
def get_status(self):
s = self
return (
f'Processed {s.nfiles_processed}/{s.nfiles} files'
f' ({sizeof_fmt(s.bytes_loaded)}/{sizeof_fmt(s.bytes_total)}'
f' @ {s.bytes_processing_rate/1e6:.1f} MB/s,'
f' In {s.bytes_input_rate/1e6:.1f} MB/s,'
f' Proc {s.processing_rate/1e6:.1f} MB/s).'
f' Head is at {tts(s.time_head)}.'
f' Queues'
f' L:{s.load_queue_size}'
f' P:{s.process_queue_size}'
f' S:{s.save_queue_size}.'
f' Estimated time remaining {s.time_remaining_str}.')
class iDASConvertConfig(Object):
batch_size = Int.T(
nthreads_loading = Int.T(
default=1,
help='Number of threads loading TDMS files in parallel.')
nthreads = Int.T(
nthreads_processing = Int.T(
default=8,
help='Number of threads for processing data.')
queue_size = Int.T(
default=32,
help='Size of the queue holding loaded traces.')
processing_batch_size = Int.T(
default=8,
help='Number of loaded files processed together in one batch.')
paths = List.T(
Path.T(),
default=[os.getcwd()],
@@ -435,4 +602,6 @@ class iDASConvertConfig(Object):
self.new_network_code, self.new_channel_code,
self.channel_selection,
self.tmin, self.tmax,
self.nthreads, self.batch_size, plugins)
self.nthreads_loading, self.nthreads_processing,
self.queue_size, self.processing_batch_size,
plugins)

@@ -3,6 +3,7 @@ from time import time
from datetime import timedelta
from pyrocko.guts import String, Float
from pyrocko.util import tts
from .plugin import Plugin, PluginConfig, register_plugin
from .utils import sizeof_fmt
@@ -21,7 +22,7 @@ class TelegramBot(Plugin):
self.status_interval = status_interval
self.started = time()
self._next_status = self.started + status_interval
self._next_status = self.started
self.bot = telebot.TeleBot(self.token)
@@ -66,28 +67,19 @@ class TelegramBot(Plugin):
logger.exception(e)
def send_finished(self, *args):
s = self.parent.stats
p = self.parent
duration = str(timedelta(seconds=p.duration))[:-7]
self.send_message(
'Finished processing {s.nfiles_processed} files'
' ({size_processed}) in {s.duration}.'.format(
s=s,
size_processed=sizeof_fmt(s.io_load_bytes_total)))
f'Finished processing {p.nfiles_processed} files'
f' ({sizeof_fmt(p.bytes_loaded)}) in {duration}.')
def send_status(self, *args):
if time() < self._next_status:
return
logger.debug('sending status message')
stats = self.parent.stats
self.send_message(
'Processed {s.nfiles_processed}/{s.nfiles_total} files '
' ({size_processed} @ {s.io_load_speed_avg:.1f} MB/s).'
' Head is at {s.processed_tmax_str:.19}.'
' Estimated time remaining {s.time_remaining_str}. '.format(
s=stats,
size_processed=sizeof_fmt(stats.io_load_bytes_total)))
self._next_status += self.status_interval
self.send_message(self.parent.get_status())
self._next_status = time() + self.status_interval
def __del__(self):
dt = timedelta(seconds=time() - self.started)
