# Source code for zdm.data_class

"""
Base dataclass utilities for parameter management.

This module provides base classes used by the parameter dataclasses
in the zdm package. These provide common functionality for serialization,
dictionary access, and parameter metadata handling.

Classes
-------
myDataClass
    Base class for parameter dataclasses with metadata access methods.
myData
    Base class for composite data objects containing multiple dataclasses.
"""

import numpy as np
import json
import pandas

from zdm import io

from dataclasses import dataclass, field, asdict
# Base class adding a few methods shared by all parameter dataclasses
@dataclass
class myDataClass:
    """Base dataclass providing field/metadata helpers.

    Intended as a base for the zdm parameter dataclasses. Subclasses declare
    their attributes as dataclass fields, optionally with
    ``dataclasses.field(metadata={...})``, and the helpers below read that
    per-field information.
    """

    @property
    def fields(self):
        """list: Names of the entries in ``__dataclass_fields__``, in order."""
        return list(self.__dataclass_fields__.keys())

    def meta(self, attribute_name):
        """Return the metadata mapping attached to a dataclass field.

        Args:
            attribute_name (str): Name of the field.

        Returns:
            Mapping: The field's ``metadata`` (read-only mapping).

        Raises:
            KeyError: If ``attribute_name`` is not a dataclass field.
        """
        return self.__dataclass_fields__[attribute_name].metadata

    def chk_options(self, attribute_name):
        """Check the current value of a field against its allowed options.

        Args:
            attribute_name (str): Name of the field. Its metadata must have
                an ``'options'`` entry (a container of allowed values).

        Returns:
            bool: True if the field's current value is one of the options.

        Raises:
            KeyError: If ``attribute_name`` is not a dataclass field, or its
                metadata has no ``'options'`` entry.
        """
        options = self.__dataclass_fields__[attribute_name].metadata['options']
        # Previously the options were looked up and discarded, so nothing was
        # actually checked; report membership of the current value instead.
        return getattr(self, attribute_name) in options
class myData:
    """Base class for composite objects holding several parameter dataclasses.

    Subclasses override :meth:`set_dataclasses` to attach dataclass instances
    as attributes. :meth:`set_params` then builds ``self.params``, a look-up
    dict mapping each parameter name to the name of the attribute that holds
    it. This lets single parameters be read or updated without naming their
    group.
    """

    def __init__(self):
        self.set_dataclasses()
        self.set_params()

    @classmethod
    def from_dict(cls, param_dict: dict):
        """Create a default instance, then overwrite values from a nested dict.

        Args:
            param_dict (dict): ``{group_name: {param_name: value}}``, e.g. the
                structure produced by :meth:`to_dict`.

        Returns:
            myData: New instance of ``cls``.
        """
        slf = cls()
        slf.update_param_dict(param_dict)
        return slf

    @classmethod
    def from_jsonfile(cls, jfile: str):
        """Load from a JSON file.

        Args:
            jfile (str): Name of the JSON file (read via ``io.process_jfile``).

        Returns:
            myData: New instance of ``cls``.
        """
        json_dict = io.process_jfile(jfile)
        return cls.from_dict(json_dict)

    @classmethod
    def from_jsonstr(cls, jsonstr: str):
        """Load from a JSON string with the same layout as :meth:`to_dict`.

        Args:
            jsonstr (str): JSON text.

        Returns:
            myData: New instance of ``cls``.
        """
        json_dict = json.loads(jsonstr)
        return cls.from_dict(json_dict)

    def set_dataclasses(self):
        """Hook for subclasses to attach their dataclass attributes (no-op here)."""
        pass

    def set_params(self):
        """Generate a simple look-up dict of parameters.

        Sets ``self.params`` to ``{param_name: attribute_name}`` for every
        instance attribute (except ``params`` itself) and every entry of that
        attribute's ``__dict__``. If two groups share a parameter name, the
        group visited later wins.
        """
        self.params = {}
        for dc_key in self.__dict__:
            if dc_key == 'params':
                continue
            for param in self[dc_key].__dict__:
                self.params[param] = dc_key

    def __getitem__(self, attrib: str):
        """Enable dict-like access to the state.

        Args:
            attrib (str): Name of the attribute.

        Returns:
            ?: Value of the attribute requested.
        """
        return getattr(self, attrib)

    def update_param_dict(self, params: dict):
        """Update parameters from a nested ``{group: {param: value}}`` dict.

        The outer (group) keys are not used for routing. Each parameter is
        sent to whichever group ``self.params`` maps it to.

        Args:
            params (dict): Nested dict of new parameter values.
        """
        for group in params.values():
            for ikey, value in group.items():
                self.update_param(ikey, value)

    def update_params(self, params: dict):
        """Update the state parameters using a flat input dict.

        Args:
            params (dict): New parameters+values, ``{param_name: value}``.
        """
        for key in params.keys():
            self.update_param(key, params[key])

    def update_param(self, param: str, value):
        """Update the value of a single parameter.

        Args:
            param (str): Name of the parameter.
            value (?): New value.

        Raises:
            KeyError: If ``param`` is not in ``self.params``.
        """
        DC = self.params[param]
        setattr(self[DC], param, value)

    def to_dict(self):
        """Generate a dict holding all of the object parameters.

        Returns:
            dict: ``{group_name: asdict(group)}`` with groups in sorted order.
        """
        # De-duplicate and sort the group names (as np.unique did) while
        # keeping plain ``str`` keys.
        groups = sorted(set(self.params.values()))
        return {group: asdict(getattr(self, group)) for group in groups}

    def write(self, outfile: str):
        """Write the parameters to a JSON file (``io.savejson``, overwrite=True).

        Args:
            outfile (str): Name of output file.
        """
        state_dict = self.to_dict()
        io.savejson(outfile, state_dict, overwrite=True, easy_to_read=True)

    def vet(self, obj, dmodel: dict, verbose=True):
        """Vet the input object against its data model.

        Args:
            obj (dict or pandas.DataFrame): Instance of the data model.
            dmodel (dict): Data model, ``{key: {'dtype': type_or_types}}``.
            verbose (bool): Print when something doesn't check.

        Returns:
            tuple: chk (bool), disallowed_keys (list), badtype_keys (list)
        """
        is_df = isinstance(obj, pandas.DataFrame)
        chk = True
        disallowed_keys = []
        badtype_keys = []
        for key in obj.keys():
            # In data model?
            if key not in dmodel:
                disallowed_keys.append(key)
                chk = False
                if verbose:
                    print("Disallowed key: {}".format(key))
                # No model entry means no dtype to check against; without this
                # the dtype lookup below raised KeyError.
                continue
            # Check data type (DataFrame columns are checked via their array)
            iobj = obj[key].values if is_df else obj[key]
            if not isinstance(iobj, dmodel[key]['dtype']):
                badtype_keys.append(key)
                chk = False
                if verbose:
                    print("Bad key type: {}".format(key))
        return chk, disallowed_keys, badtype_keys

    def __repr__(self) -> str:
        return json.dumps(self.to_dict(), sort_keys=True, indent=4,
                          separators=(',', ': '))