Browse Source

Added TerminalListener // Cleaned HighScoreSolver

whitelist
miili 3 years ago
parent
commit
d04383de09
9 changed files with 308 additions and 58 deletions
  1. +2
    -1
      setup.py
  2. +2
    -0
      src/core.py
  3. +2
    -0
      src/listeners/__init__.py
  4. +13
    -0
      src/listeners/base.py
  5. +131
    -0
      src/listeners/curses.py
  6. +74
    -0
      src/listeners/terminal.py
  7. +5
    -5
      src/meta.py
  8. +58
    -2
      src/solvers/base.py
  9. +21
    -50
      src/solvers/highscore.py

+ 2
- 1
setup.py View File

@ -15,7 +15,8 @@ setup(
version='0.1',
author='Sebastian Heimann',
author_email='sebastian.heimann@gfz-potsdam.de',
packages=['grond', 'grond.baraddur', 'grond.problems', 'grond.solvers', 'grond.analysers'],
packages=['grond', 'grond.baraddur', 'grond.problems', 'grond.solvers',
'grond.analysers', 'grond.listeners'],
scripts=['apps/grond'],
package_dir={'grond': 'src'},
package_data={'grond': ['baraddur/templates/*.html',


+ 2
- 0
src/core.py View File

@ -15,6 +15,7 @@ from .dataset import DatasetConfig, NotFound
from .problems.base import ProblemConfig, Problem
from .solvers.base import SolverConfig
from .analysers.base import AnalyserConfig
from .listeners import TerminalListener
from .targets import TargetConfig
from .meta import Path, HasPaths, expand_template, xjoin, GrondError, Notifier
@ -704,6 +705,7 @@ def process_event(ievent, g_data_id):
'start %i / %i' % (ievent+1, nevents))
notifier = Notifier()
notifier.add_listener(TerminalListener())
analyser = config.analyser_config.get_analyser()
analyser.analyse(problem, notifier=notifier)


+ 2
- 0
src/listeners/__init__.py View File

@ -0,0 +1,2 @@
from .curses import CursesListener # noqa
from .terminal import TerminalListener # noqa

+ 13
- 0
src/listeners/base.py View File

@ -0,0 +1,13 @@
class Listener(object):
def progress_start(self, name, niter):
raise NotImplementedError()
def progress_finish(self, name):
raise NotImplementedError()
def progress_update(self, name, iiter):
raise NotImplementedError()
def state(self, state):
raise NotImplementedError()

+ 131
- 0
src/listeners/curses.py View File

@ -0,0 +1,131 @@
import curses
class State(object):
iiter = 0
niter = 0
iter_sec = 0.
problem_name = ''
parameter_names = []
column_names = []
values = []
text = ''
class _CursesPad(object):
def __init__(self, pad):
self.pad = pad
self.rows, self.cols = self.pad.getyx()
def resize_pad(self):
return
self.pad.resize(self.rows, self.cols)
class CursesListener(object):
class ParameterTable(_CursesPad):
value_fmt = '{0:8.4g}'
column_padding = 2
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
parameter_names = ['Parameters'] + state.parameters
col = 0
for icol in xrange(len(state.values) + 1):
row = 0
if icol == 0:
col_width = max([len(p) for p in parameter_names])
for name in parameter_names:
pad.addstr(
row, col,
'{:<{width}}'.format(
name, width=col_width),
curses.A_BOLD)
row += 1
else:
igroup = icol - 1
col_heading = state.column_names[igroup]
col_width = max(
len(col_heading),
len(self.value_fmt.format(0.)) + self.column_padding)
pad.addstr(row, col,
'{:>{width}}'.format(
col_heading, width=col_width),
curses.A_BOLD)
for iv, v in enumerate(state.values[igroup]):
row += 1
vstr = ' ' * self.column_padding +\
self.value_fmt.format(v)
pad.addstr(row, col, vstr)
col += col_width
self.rows = row
self.resize_pad()
pad.noutrefresh()
class Footer(_CursesPad):
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
pad.addstr(0, 0, 'Performance:')
pad.addstr(0, 14, '%.1f iter/s' % state.iter_sec)
pad.addstr(1, 0, state.text)
self.rows = 3
self.resize_pad()
pad.noutrefresh()
class Header(_CursesPad):
def update(self, state):
pad = self.pad
pad.clear()
if not state:
return
pad.addstr(0, 0, 'Problem Name:')
pad.addstr(0, 14, state.problem_name,
curses.A_BOLD)
pad.addstr(1, 0, 'Iteration:')
pad.addstr(1, 14, '%d / %d' % (state.iiter, state.niter),
curses.A_BOLD)
self.rows = 3
self.resize_pad()
pad.noutrefresh()
def __init__(self):
self.scr = None
self.state = None
curses.wrapper(self.set_screen)
self.header_pad = self.Header(self.scr.subpad(3, 100, 0, 0))
self.parameter_pad = self.ParameterTable(self.scr.subpad(3, 0))
self.footer_pad = self.Footer(self.scr.subpad(3, 100, 5, 0))
def set_screen(self, scr):
self.scr = scr
def set_state(self, state):
self.state = state
self.parameter_pad.update(state)
self.header_pad.update(state)
self.footer_pad.update(state)
self.footer_pad.pad.mvwin(
self.parameter_pad.rows + self.parameter_pad.pad.getparyx()[0] + 2,
0)
self.scr.refresh()

+ 74
- 0
src/listeners/terminal.py View File

@ -0,0 +1,74 @@
import progressbar as pbar
from .base import Listener
class color:
PURPLE = '\033[95m'
CYAN = '\033[96m'
DARKCYAN = '\033[36m'
BLUE = '\033[94m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
RED = '\033[91m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
END = '\033[0m'
class TerminalListener(Listener):
col_width = 15
row_name = color.BOLD + '{:<{col_param_width}s}' + color.END
parameter_fmt = '{:>{col_width}{type}}'
def __init__(self):
self.current_state = None
self.pbars = {}
def progress_start(self, name, niter):
self.pbars[name] = pbar.start(name, niter)
def progress_update(self, name, iiter):
self.pbars[name].update(iiter)
def progress_finish(self, name):
self.pbars[name].finish()
def state(self, state):
lines = []
self.current_state = state
def l(t):
lines.append(t)
out_ln = self.row_name +\
''.join([self.parameter_fmt] * len(state.parameter_values))
col_param_width = max([len(p) for p in state.parameter_names]) + 2
l('Problem name: {s.problem_name}'
'\t({s.runtime:s} - remaining {s.runtime_remaining})'
.format(s=state))
l('Iteration {s.iiter} / {s.niter} ({s.iter_per_second:.1f} iter/s)'
.format(s=state))
l(out_ln.format(
*['Parameter'] + state.column_names,
col_param_width=col_param_width,
col_width=self.col_width,
type='s'))
for ip, parameter_name in enumerate(state.parameter_names):
l(out_ln.format(
parameter_name,
*[v[ip] for v in state.parameter_values],
col_param_width=col_param_width,
col_width=self.col_width,
type='.4g'))
l(state.extra_text.format(
col_param_width=col_param_width,
col_width=self.col_width,))
lines[0:0] = ['\033[2J']
l('')
print '\n'.join(lines)

+ 5
- 5
src/meta.py View File

@ -198,9 +198,9 @@ class Notifier(object):
def emit(self, signal_name, *args, **kwargs):
for listener in self._listeners:
try:
getattr(listener, signal_name)(*args, **kwargs)
except AttributeError:
if not hasattr(listener, signal_name):
logger.warn(
'signal name %s not implemented in listener' % signal_name)
'signal name \'%s\' not implemented in listener %s'
% (signal_name, type(listener)))
continue
getattr(listener, signal_name)(*args, **kwargs)

+ 58
- 2
src/solvers/base.py View File

@ -1,4 +1,7 @@
import logging
import time
import numpy as num
from datetime import timedelta
from pyrocko.guts import Object
@ -7,10 +10,62 @@ guts_prefix = 'grond'
logger = logging.getLogger('grond.solver')
class RingBuffer(num.ndarray):
def __init__(self, *args, **kwargs):
num.ndarray.__init__(self, *args, **kwargs)
self.fill(0.)
self.pos = 0
def put(self, value):
self[self.pos] = value
self.pos += 1
self.pos %= self.size
class SolverState(object):
problem_name = ''
parameter_names = []
parameter_values = []
column_names = []
extra_text = ''
niter = 0
_iiter = 0
iter_per_second = 0.
_iter_buffer = RingBuffer(20)
starttime = time.time()
_last_update = time.time()
@property
def iiter(self):
return self._iiter
@iiter.setter
def iiter(self, value):
dt = time.time() - self._last_update
self._iter_buffer.put(float((value - self._iiter) / dt))
self.iter_per_second = float(self._iter_buffer.mean())
self._iiter = value
self._last_update = time.time()
@property
def runtime(self):
return timedelta(seconds=time.time() - self.starttime)
@property
def runtime_remaining(self):
if self.iter_per_second == 0.:
return timedelta()
return timedelta(seconds=(self.niter - self.iiter)
/ self.iter_per_second)
class Solver(object):
def solve(
self, problem, rundir=None, status=(), plot=None, xs_inject=None):
state = SolverState()
def solve(
self, problem, rundir=None, status=(), plot=None, xs_inject=None,
notifier=None):
raise NotImplemented()
@ -22,5 +77,6 @@ class SolverConfig(Object):
__all__ = '''
Solver
SolverState
SolverConfig
'''.split()

+ 21
- 50
src/solvers/highscore.py View File

@ -69,7 +69,8 @@ def solve(problem,
xs_inject=None,
status=(),
plot=None,
notifier=None):
notifier=None,
state=None):
xbounds = num.array(problem.get_parameter_bounds(), dtype=num.float)
npar = problem.nparameters
@ -99,6 +100,12 @@ def solve(problem,
accept_hist = num.zeros(niter, dtype=num.int)
pnames = problem.parameter_names
state.problem_name = problem.name
state.column_names = ['B mean', 'B std',
'G mean', 'G std', 'G best']
state.parameter_names = problem.parameter_names + ['Misfit']
state.niter = niter
if plot:
plot.start(problem)
@ -300,22 +307,12 @@ def solve(problem,
accept_sum += accept
accept_hist[iiter] = num.sum(accept)
lines = []
if 'state' in status:
lines.append('%s, %i' % (problem.name, iiter))
lines.append(''.join('-X'[int(acc)] for acc in accept))
xhist[iiter, :] = x
bxs = xhist[chains_i[:, :nlinks].ravel(), :]
gxs = xhist[chains_i[0, :nlinks], :]
gms = chains_m[0, :nlinks]
col_width = 15
col_param_width = max([len(p) for p in problem.parameter_names])+2
console_output = '{:<{col_param_width}s}'
console_output += ''.join(['{:>{col_width}{type}}'] * 5)
if nlinks > (nlinks_cap-1)/2:
# mean and std of all bootstrap ensembles together
mbx = num.mean(bxs, axis=0)
@ -353,52 +350,25 @@ def solve(problem,
else:
assert False, 'invalid standard_deviation_estimator choice'
if 'state' in status:
lines.append(
console_output.format(
'parameter', 'B mean', 'B std', 'G mean', 'G std',
'G best',
col_param_width=col_param_width,
col_width=col_width,
type='s'))
for (pname, mbv, sbv, mgv, sgv, bgv) in zip(
pnames, mbx, sbx, mgx, sgx, bgx):
lines.append(
console_output.format(
pname, mbv, sbv, mgv, sgv, bgv,
col_param_width=col_param_width,
col_width=col_width,
type='.4g'))
lines.append(
console_output.format(
'misfit', '', '',
'%.4g' % num.mean(gms),
'%.4g' % num.std(gms),
'%.4g' % num.min(gms),
col_param_width=col_param_width,
col_width=col_width,
type='s'))
state.parameter_values = [
num.append(mbx, num.nan),
num.append(sbx, num.nan),
num.append(mgx, num.mean(gms)),
num.append(sgx, num.std(gms)),
num.append(bgx, num.min(gms))]
state.iiter = iiter + 1
state.extra_text =\
'Phase: %s (factor %d); ntries %d, ntries_preconstrain %d'\
% (phase, factor, ntries_sample, ntries_preconstrain)
if 'state' in status:
lines.append(
console_output.format(
'iteration', iiter+1, '(%s, %g)' % (phase, factor),
ntries_sample, ntries_preconstrain, '',
col_param_width=col_param_width,
col_width=col_width,
type=''))
notifier.emit('state', state)
if 'matrix' in status:
lines = []
matrix = (chains_i[:, :30] % 94 + 32).T
for row in matrix[::-1]:
lines.append(''.join(chr(xxx) for xxx in row))
if status:
lines[0:0] = ['\033[2J']
lines.append('')
print '\n'.join(lines)
if plot and plot.want_to_update(iiter):
@ -433,6 +403,7 @@ class HighScoreSolver(Solver):
plot=plot,
xs_inject=xs_inject,
notifier=notifier,
state=self.state,
**self._kwargs)


Loading…
Cancel
Save