Compare commits

...

8 Commits

  1. 32
      src/apps/grond.py
  2. 42
      src/core.py
  3. 16
      src/dataset.py
  4. 37
      src/targets/base.py
  5. 41
      src/targets/waveform/target.py

32
src/apps/grond.py

@ -807,17 +807,35 @@ def command_harvest(args):
'average misfit of all NEACH best in all chains, '
'3: harvesting is done on the global chain only, bootstrap '
'chains are excluded')
parser.add_option(
'--export-fits', dest='export_fits', default='',
help='additionally export details about the fit of individual '
'targets. "best" - export fits of best model, "mean" - '
'export fits of ensemble mean model, "ensemble" - export '
'fits of all models in harvest ensemble.')
parser, options, args = cl_parse('harvest', args, setup)
if len(args) != 1:
if len(args) < 1:
help_and_die(parser, 'no rundir')
run_path, = args
grond.harvest(
run_path,
force=options.force,
nbest=options.neach,
weed=options.weed)
export_fits = []
if options.export_fits.strip():
export_fits = [x.strip() for x in options.export_fits.split(',')]
for run_path in args:
try:
grond.harvest(
run_path,
force=options.force,
nbest=options.neach,
weed=options.weed,
export_fits=export_fits)
except grond.DirectoryAlreadyExists as e:
die(str(e) + '\n Use --force to overwrite.')
except grond.GrondError as e:
die(str(e))
def command_cluster(args):

42
src/core.py

@ -21,6 +21,8 @@ from .problems.base import Problem, load_problem_info_and_data, \
from .optimisers.base import BadProblem
from .targets.waveform.target import WaveformMisfitResult
from .targets.base import dump_misfit_result_collection, \
MisfitResultCollection, MisfitResult, MisfitResultError
from .meta import expand_template, GrondError, selected
from .environment import Environment
from .monitor import GrondMonitor
@ -65,7 +67,7 @@ def lock_rundir(rundir):
os.remove(statefn)
class DirectoryAlreadyExists(Exception):
class DirectoryAlreadyExists(GrondError):
pass
@ -146,7 +148,9 @@ def forward(env, show='filtered'):
trace.snuffle(all_trs, markers=markers, stations=list(stations.values()))
def harvest(rundir, problem=None, nbest=10, force=False, weed=0):
def harvest(
rundir, problem=None, nbest=10, force=False, weed=0,
export_fits=[]):
env = Environment([rundir])
optimiser = env.get_optimiser()
@ -166,7 +170,8 @@ def harvest(rundir, problem=None, nbest=10, force=False, weed=0):
if force:
shutil.rmtree(dumpdir)
else:
raise DirectoryAlreadyExists(dumpdir)
raise DirectoryAlreadyExists(
'Harvest directory already exists: %s' % dumpdir)
util.ensuredir(dumpdir)
@ -205,6 +210,36 @@ def harvest(rundir, problem=None, nbest=10, force=False, weed=0):
for i in ibests:
problem.dump_problem_data(dumpdir, xs[i], misfits[i, :, :])
if export_fits:
env.setup_modelling()
problem = env.get_problem()
history = env.get_history(subset='harvest')
for what in export_fits:
if what == 'best':
models = [history.get_best_model()]
elif what == 'mean':
models = [history.get_mean_model()]
elif what == 'ensemble':
models = history.models
else:
raise GrondError(
'Invalid option for harvest\'s export_fits argument: %s'
% what)
results = []
for x in models:
results.append([
(result if isinstance(result, MisfitResult)
else MisfitResultError(message=str(result))) for
result in problem.evaluate(x)])
emr = MisfitResultCollection(results=results)
dump_misfit_result_collection(
emr,
op.join(dumpdir, 'fits-%s.yaml' % what))
logger.info('Done harvesting problem "%s".' % problem.name)
@ -799,6 +834,7 @@ def export(
__all__ = '''
DirectoryAlreadyExists
forward
harvest
cluster

16
src/dataset.py

@ -60,6 +60,12 @@ class StationCorrection(Object):
factor = Float.T()
class WFTargetMisfit(Object):
    '''
    Record holding the misfit of one target, identified by four codes.

    Instances are written out with :py:func:`dump_wftarget_misfits`.
    '''

    # Exactly four string codes identifying the target (presumably network,
    # station, location, channel -- TODO confirm against the producer).
    codes = Tuple.T(4, String.T())

    # Misfit value and the norm it relates to (both mandatory floats).
    misfit = Float.T()
    norm = Float.T()

    # Optional weight; may be left unset.
    weight = Float.T(optional=True)
def load_station_corrections(filename):
scs = load_all(filename=filename)
for sc in scs:
@ -68,8 +74,12 @@ def load_station_corrections(filename):
return scs
def dump_station_corrections(station_corrections, filename):
return dump_all(station_corrections, filename=filename)
def dump_station_corrections(station_corrections, filename=None, stream=None):
    '''
    Serialise station corrections via ``dump_all``.

    :param station_corrections: objects to be dumped
    :param filename: forwarded to ``dump_all`` as ``filename``
    :param stream: forwarded to ``dump_all`` as ``stream``
    :returns: whatever ``dump_all`` returns
    '''
    dumped = dump_all(station_corrections, filename=filename, stream=stream)
    return dumped
def dump_wftarget_misfits(wftarget_misfits, filename=None, stream=None):
    '''
    Serialise waveform target misfit records via ``dump_all``.

    :param wftarget_misfits: objects to be dumped
    :param filename: forwarded to ``dump_all`` as ``filename``
    :param stream: forwarded to ``dump_all`` as ``stream``
    :returns: whatever ``dump_all`` returns
    '''
    dumped = dump_all(wftarget_misfits, filename=filename, stream=stream)
    return dumped
class Dataset(object):
@ -1157,6 +1167,8 @@ __all__ = '''
InvalidObject
NotFound
StationCorrection
WFTargetMisfit
load_station_corrections
dump_station_corrections
dump_wftarget_misfits
'''.split()

37
src/targets/base.py

@ -4,10 +4,10 @@ import numpy as num
from pyrocko import gf
from pyrocko.guts_array import Array
from pyrocko.guts import Object, Float, Dict
from pyrocko.guts import Object, Float, String, Dict, List, Choice, load, dump
from grond.analysers.base import AnalyserResult
from grond.meta import has_get_plot_classes
from grond.meta import has_get_plot_classes, GrondError
guts_prefix = 'grond'
@ -188,8 +188,41 @@ class MisfitTarget(Object):
return modelling_results[0]
class MisfitResultError(Object):
    '''
    Stand-in entry for a misfit result that is not available.

    Carries only a textual description; it is one of the two element types
    accepted by :py:class:`MisfitResultCollection`.
    '''

    # Human-readable description of what went wrong.
    message = String.T()
class MisfitResultCollection(Object):
    '''
    Nested collection of misfit results.

    ``results`` is a list of lists. Every inner element is either a
    :py:class:`MisfitResult` or a :py:class:`MisfitResultError`.
    '''

    results = List.T(
        List.T(
            Choice.T([
                MisfitResult.T(),
                MisfitResultError.T()])))
def dump_misfit_result_collection(misfit_result_collection, path):
    '''
    Write a misfit result collection to file.

    :param misfit_result_collection: object handed to ``dump``
    :param path: name of the output file

    Nothing is returned; the result of ``dump`` is discarded.
    '''
    dump(
        misfit_result_collection,
        filename=path)
def load_misfit_result_collection(path):
    '''
    Read a misfit result collection from file.

    :param path: name of the file to read
    :returns: the loaded :py:class:`MisfitResultCollection`
    :raises GrondError: if ``load`` fails with an :py:exc:`OSError`, or if
        the loaded object is not a :py:class:`MisfitResultCollection`
    '''
    try:
        collection = load(filename=path)

    except OSError as err:
        # Only OS-level failures are translated; any other exception raised
        # by ``load`` propagates unchanged.
        raise GrondError(
            'Failed to read ensemble misfit results from file "%s" (%s)' % (
                path, err))

    if isinstance(collection, MisfitResultCollection):
        return collection

    raise GrondError(
        'File "%s" does not contain any misfit result collection.' % path)
__all__ = '''
TargetGroup
MisfitTarget
MisfitResult
MisfitResultError
dump_misfit_result_collection
load_misfit_result_collection
MisfitResultCollection
'''.split()

41
src/targets/waveform/target.py

@ -59,7 +59,7 @@ class StationDictStoreIDSelector(StoreIDSelector):
'''
mapping = Dict.T(
String.T(), String.T(),
String.T(), gf.StringID.T(),
help='Dictionary with station to store ID pairs, keys are NET.STA. '
"Add a fallback store ID under the key ``'others'``.")
@ -77,6 +77,29 @@ class StationDictStoreIDSelector(StoreIDSelector):
return store_id
class DepthRangeToStoreID(Object):
    '''
    Association of a depth interval with a store ID.

    Used as the element type of
    :py:attr:`StationDepthStoreIDSelector.depth_ranges`, which treats the
    interval as ``depth_min <= depth < depth_max``.
    '''

    # Lower (inclusive) and upper (exclusive) bound as applied by the selector.
    depth_min = Float.T()
    depth_max = Float.T()

    # Store ID returned for stations falling into the interval.
    store_id = gf.StringID.T()
class StationDepthStoreIDSelector(StoreIDSelector):
    '''
    Choose the store ID according to the depth of the station.

    The entries of ``depth_ranges`` are tried in order and the first one
    whose half-open interval ``[depth_min, depth_max)`` contains the station
    depth determines the store ID.
    '''

    depth_ranges = List.T(DepthRangeToStoreID.T())

    def get_store_id(self, event, st, cha):
        '''
        Return the store ID of the first depth range containing ``st.depth``.

        ``event`` and ``cha`` are accepted but not used here.

        :raises StoreIDSelectorError: if no configured range matches
        '''
        depth = st.depth
        for depth_range in self.depth_ranges:
            in_range = depth_range.depth_min <= depth < depth_range.depth_max
            if not in_range:
                continue

            return depth_range.store_id

        raise StoreIDSelectorError(
            'No store ID found for station "%s.%s" at %g m depth.' % (
                st.network, st.station, depth))
class DomainChoice(StringChoice):
choices = [
'time_domain',
@ -87,10 +110,6 @@ class DomainChoice(StringChoice):
'cc_max_norm']
class Trace(Object):
pass
class WaveformMisfitConfig(MisfitConfig):
quantity = gf.QuantityType.T(default='displacement')
fmin = Float.T(default=0.0, help='minimum frequency of bandpass filter')
@ -336,10 +355,10 @@ class WaveformMisfitResult(gf.Result, MisfitResult):
A number of different waveform or phase representations are possible.
'''
processed_obs = Trace.T(optional=True)
processed_syn = Trace.T(optional=True)
filtered_obs = Trace.T(optional=True)
filtered_syn = Trace.T(optional=True)
processed_obs = trace.Trace.T(optional=True)
processed_syn = trace.Trace.T(optional=True)
filtered_obs = trace.Trace.T(optional=True)
filtered_syn = trace.Trace.T(optional=True)
spectrum_obs = TraceSpectrum.T(optional=True)
spectrum_syn = TraceSpectrum.T(optional=True)
@ -347,7 +366,7 @@ class WaveformMisfitResult(gf.Result, MisfitResult):
tobs_shift = Float.T(optional=True)
tsyn_pick = Timestamp.T(optional=True)
tshift = Float.T(optional=True)
cc = Trace.T(optional=True)
cc = trace.Trace.T(optional=True)
piggyback_subresults = List.T(WaveformPiggybackSubresult.T())
@ -778,6 +797,8 @@ __all__ = '''
StoreIDSelector
Crust2StoreIDSelector
StationDictStoreIDSelector
DepthRangeToStoreID
StationDepthStoreIDSelector
WaveformTargetGroup
WaveformMisfitConfig
WaveformMisfitTarget

Loading…
Cancel
Save