# Source code for pyro.util.compare

#!/usr/bin/env python3


import sys

import numpy as np

import pyro.util.io_pyro as io

# Help text shown by main() when the script is invoked with the wrong
# number of command-line arguments.
usage = """
      usage: ./compare.py file1 file2 (rtol)

      where rtol is an (optional) relative tolerance parameter to use when
      comparing the data
"""

# Human-readable messages, keyed by the failure codes that compare() returns.
errors = dict(
    gridbad="grids don't agree",
    namesbad="variable lists don't agree",
    varerr="one or more variables don't agree",
)


def compare(data1, data2, rtol=1.e-12, atol=1.e-8):
    """
    given two CellCenterData2d objects, compare the data, zone-by-zone
    and output any errors

    A per-variable summary (maximum absolute error and, when the reference
    data contains no exact zeros, maximum relative error) is printed to
    stdout.

    Parameters
    ----------
    data1, data2 : CellCenterData2d object
        Two data grids to compare.  ``data2`` is treated as the reference
        when computing the relative error.
    rtol : float
        relative tolerance to use to compare grids
    atol : float
        absolute tolerance to use to compare grids.  The default matches
        the default of ``np.allclose``, which is what was implicitly used
        before this parameter existed.  Note that for data of order unity
        this default is far looser than the default ``rtol``; pass
        ``atol=0.0`` for a purely relative comparison.

    Returns
    -------
    int or str
        ``0`` if the grids, variable lists and all variables agree;
        otherwise one of the error codes ``"gridbad"``, ``"namesbad"`` or
        ``"varerr"``.
    """

    # compare the grids -- `not a == b` is kept (rather than `a != b`) so
    # that only the grid's equality operator is relied upon
    if not data1.grid == data2.grid:
        return "gridbad"

    # compare the variable lists (order-insensitive)
    if not sorted(data1.names) == sorted(data2.names):
        return "namesbad"

    print(" ")
    print("variable comparisons:")

    result = 0

    for name in data1.names:
        # fetch each variable's array once and reuse it below
        v1 = data1.get_var(name).v()
        v2 = data2.get_var(name).v()

        abs_diff = np.abs(v1 - v2)
        abs_err = np.max(abs_diff)

        # the relative error is only reported when it is well defined
        # everywhere, i.e. the reference data has no exact zeros
        if not np.any(v2 == 0):
            rel_err = np.max(abs_diff / np.abs(v2))
            print(f"{name:20s} absolute error = {abs_err:10.10g}, relative error = {rel_err:10.10g}")
        else:
            print(f"{name:20s} absolute error = {abs_err:10.10g}")

        # keep looping after a failure so every variable is reported
        if not np.allclose(v1, v2, rtol=rtol, atol=atol):
            result = "varerr"

    return result
def main():
    """
    command-line driver: compare two pyro output files

    Reads ``file1`` and ``file2`` from ``sys.argv`` (plus an optional
    relative tolerance), loads both with ``io.read``, compares their
    ``cc_data`` with ``compare`` and prints the outcome.

    Returns
    -------
    int
        ``0`` if the files agree, ``1`` if they do not.  The process exits
        with status 2 (after printing the usage text) when the arguments
        are malformed.
    """

    argv = sys.argv

    if len(argv) not in (3, 4):
        print(usage)
        sys.exit(2)

    file1, file2 = argv[1], argv[2]

    # validate the optional tolerance before doing any file I/O so a typo
    # is reported as a usage error rather than a traceback
    options = {}
    if len(argv) == 4:
        try:
            options["rtol"] = float(argv[3])
        except ValueError:
            print(f"ERROR: rtol must be a number, got {argv[3]!r}")
            print(usage)
            sys.exit(2)

    s1 = io.read(file1)
    s2 = io.read(file2)

    result = compare(s1.cc_data, s2.cc_data, **options)

    if result == 0:
        print("SUCCESS: files agree")
        return 0

    print("ERROR: ", errors[result])
    return 1


if __name__ == "__main__":
    # propagate the comparison outcome to the shell so the script can be
    # used in automated regression checks
    sys.exit(main())