A seismology toolkit for Python
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

137 lines
3.2 KiB

  1. # python 2/3
  2. from __future__ import division, print_function, absolute_import
  3. import sys
  4. import unittest
  5. import os
  6. import glob
  7. try:
  8. from urllib2 import HTTPError
  9. except ImportError:
  10. from urllib.error import HTTPError
  11. from .. import common
  12. from pyrocko import ExternalProgramMissing
  13. from pyrocko import util
  14. from pyrocko import example
  15. from pyrocko import pile
  16. from pyrocko.plot import gmtpy
  17. from pyrocko.gui import snuffler
  18. from pyrocko.dataset import topo
  19. op = os.path
  20. test_dir = op.dirname(op.abspath(__file__))
  21. skip_examples = [
  22. 'trace_restitution_dseed.py',
  23. 'gf_forward_viscoelastic.py'
  24. ]
  25. def tutorial_run_dir():
  26. return op.join(test_dir, '..', 'example_run_dir')
  27. def noop(*args, **kwargs):
  28. pass
  29. class ExamplesTestCase(unittest.TestCase):
  30. @classmethod
  31. def setUpClass(cls):
  32. from matplotlib import pyplot as plt
  33. cls.cwd = os.getcwd()
  34. cls.run_dir = tutorial_run_dir()
  35. util.ensuredir(cls.run_dir)
  36. cls.dn = open(os.devnull, 'w')
  37. sys.stdout = cls.dn
  38. os.chdir(cls.run_dir)
  39. plt.show_orig_testex = plt.show
  40. plt.show = noop
  41. snuffler.snuffle_orig_testex = snuffler.snuffle
  42. snuffler.snuffle = noop
  43. cls._show_progress_force_off_orig = pile.show_progress_force_off
  44. pile.show_progress_force_off = True
  45. @classmethod
  46. def tearDownClass(cls):
  47. from matplotlib import pyplot as plt
  48. cls.dn.close()
  49. sys.stdout = sys.__stdout__
  50. os.chdir(cls.cwd)
  51. snuffler.snuffle = snuffler.snuffle_orig_testex
  52. plt.show = plt.show_orig_testex
  53. pile.show_progress_force_off = cls._show_progress_force_off_orig
  54. example_files = [fn for fn in glob.glob(op.join(test_dir, 'examples', '*.py'))
  55. if os.path.basename(fn) not in skip_examples]
  56. def _make_function(test_name, fn):
  57. def f(self):
  58. imp = imp2 = None
  59. try:
  60. import imp
  61. except ImportError:
  62. import importlib.machinery as imp2
  63. try:
  64. if imp:
  65. imp.load_source(test_name, fn)
  66. else:
  67. imp2.SourceFileLoader(test_dir, fn)
  68. except example.util.DownloadError:
  69. raise unittest.SkipTest('could not download required data file')
  70. except HTTPError as e:
  71. raise unittest.SkipTest('skipped due to %s: "%s"' % (
  72. e.__class__.__name__, str(e)))
  73. except ExternalProgramMissing as e:
  74. raise unittest.SkipTest(str(e))
  75. except ImportError as e:
  76. raise unittest.SkipTest(str(e))
  77. except topo.AuthenticationRequired:
  78. raise unittest.SkipTest(
  79. 'cannot download topo data (no auth credentials)')
  80. except gmtpy.GMTInstallationProblem:
  81. raise unittest.SkipTest('GMT not installed or not usable')
  82. except Exception as e:
  83. raise e
  84. f.__name__ = 'test_example_' + test_name
  85. return f
  86. for fn in sorted(example_files):
  87. test_name = op.splitext(op.split(fn)[-1])[0]
  88. setattr(
  89. ExamplesTestCase,
  90. 'test_example_' + test_name,
  91. _make_function(test_name, fn))
# Script entry point: configure the environment, then run the test suite.
if __name__ == '__main__':
    # Set up pyrocko logging for program 'test_examples' at level
    # 'warning' (argument meaning per util.setup_logging -- confirm there).
    util.setup_logging('test_examples', 'warning')
    # Presumably selects matplotlib's non-interactive Agg backend; it is
    # called before unittest.main(), i.e. before setUpClass imports pyplot.
    common.matplotlib_use_agg()
    unittest.main()