Source code for STJ_PV.input_data

# -*- coding: utf-8 -*-
"""Generate or load input data for STJ Metric."""
import os
import numpy as np
import pkg_resources
import datetime as dt
import xarray as xr
# Dependent code
import STJ_PV.utils as utils

__author__ = "Penelope Maher, Michael Kelleher"


[docs]def package_data(relpath, file_name): """Get data relative to this installed package. Generally used for the sample data.""" _data_dir = pkg_resources.resource_filename('STJ_PV', relpath) return xr.open_dataset(os.path.join(_data_dir, file_name))
[docs]class InputData: """ Contains the relevant input data and routines for an JetFindRun. Parameters ---------- jet_find : :py:meth:`~STJ_PV.run_stj.JetFindRun` Object containing properties about the metric calculation to be performed. Used to locate correct files, and variables within those files. year : int, optional Year of data to load, not used when all years are in a single file """ # For the default InputData class, there are no required fields # this should be overridden in child classes for each metric load_vars = [] def __init__(self, props, date_s=None, date_e=None): """Initialize InputData object, using JetFindRun class.""" self.props = props self.data_cfg = props.data_cfg if date_s is not None: self.year = date_s.year else: self.year = None self.in_data = {} self.out_data = {} self.sel = { self.data_cfg['time']: slice(None), self.data_cfg['lev']: slice(None), self.data_cfg['lat']: slice(None), self.data_cfg['lon']: slice(None), } self.chunk = { self.data_cfg['time']: None, self.data_cfg['lev']: None, self.data_cfg['lat']: None, self.data_cfg['lon']: None, } if date_s is not None or date_e is not None: self.sel[self.data_cfg['time']] = slice(date_s, date_e) self._select_setup() def _select_setup(self): """Set up selection dictionary to be passed to xarray.DataArray.sel.""" cfg = self.data_cfg for cvar in ['lon', 'lev', 'lat']: # In the data config file, the start and end values will be # labeled as lon_s, lon_e (for longitude as an example) _start = '{}_s'.format(cvar) _end = '{}_e'.format(cvar) _skip = '{}_skp'.format(cvar) _slice = ( cfg.get(_start, None), cfg.get(_end, None), cfg.get(_skip, None), ) self.sel[cfg[cvar]] = slice(*_slice) # if _start and _end in cfg: # self.sel[cfg[cvar]] = slice(cfg[_start], cfg[_end]) def _load_data(self): """Load all required data.""" for data_var in self.load_vars: try: self._load_one_file(data_var) except (KeyError, FileNotFoundError): self.props.log.info('FILE FOR %s NOT FOUND', data_var) def _chunk_data(self, var): """Re-chunk input data to ideal size.""" self.in_data[var] = self.in_data[var].chunk(self.chunk) def _set_chunks(self, data, max_size=3e6, excldims=('lev', 'lat')): """Get ideal-ish chunks in a different way.""" shape = data.shape npoints = np.prod(shape) excl_n = [self.data_cfg[excldim] for excldim in excldims] dims = [ coord for coord in data.coords if coord in data.dims and coord not in excl_n ] chunks = {dim: data[dim].shape[0] for dim in dims} divis = {dim: 1 for dim in dims} n_iter = 0 while npoints > max_size and n_iter < 10: _pts = np.prod([data[name].shape[0] for name in excl_n]) for dim in dims: divis[dim] *= 2 _pts *= chunks[dim] // divis[dim] npoints = _pts n_iter += 1 self.props.log.info(f' Chunking took {n_iter} iterations') chunks_out = {dim: chunks[dim] // divis[dim] for dim in dims} for exname in excl_n: chunks_out[exname] = data[exname].shape[0] for dim in chunks_out: # Quick sanity check to make sure we don't divide by 0 if chunks_out[dim] == 0: chunks_out[dim] = 1 self.props.log.info(f' - {dim}: {chunks_out[dim]}') self.props.log.info(f' Chunk size: {npoints}') self.chunk = chunks_out def _load_one_file(self, var, file_var=None): """Load a single netCDF file as an xarray.Dataset.""" cfg = self.data_cfg vname = cfg[var] # Use this to set the file variable name (look for uwnd in ipv file) if file_var is None: file_var = var # Format the name of the file, join it with the path, open it try: file_name = cfg['file_paths'][file_var].format(year=self.year) except KeyError: file_name = cfg['file_paths']['all'].format(year=self.year) self.props.log.info( 'OPEN: {}'.format(os.path.join(cfg['path'], file_name)) ) try: nc_file = xr.open_dataset(os.path.join(cfg['path'], file_name)) except FileNotFoundError: nc_file = package_data(cfg['path'], file_name) self.in_data[var] = nc_file[vname].sel(**self.sel) _fails = 0 while self.in_data[var][cfg['time']].shape[0] == 0 and _fails < 15: # Update the time slice so that it covers potential mis-match # between how days are requested and how they're stored in the # netCDF file (e.g. ask for 2013-07-14 00:00, but the netCDF # file has 2013-07-14 09:00) _day = dt.timedelta(hours=23) self.sel[cfg['time']] = slice( self.sel[cfg['time']].start, self.sel[cfg['time']].stop + _day ) self.in_data[var] = nc_file[vname].sel(**self.sel) self.props.log.info( 'UPDATING TIME SLICE BY 1 DAY %s', (self.sel[cfg['time']].stop.strftime('%Y-%m-%d %HZ')), ) # Iterate, but don't get stuck here _fails += 1 if all([self.chunk[var] is None for var in self.chunk]): self._set_chunks(self.in_data[var]) self._chunk_data(var)
[docs] def get_data(self): """Get a single xarray.Dataset of required components for metric.""" data = xr.Dataset( self.out_data, attrs={'cfg': self.data_cfg, 'year': self.year} ) return data
[docs] def write_data(self, out_file=None): """Write netCDF file of the out_data.""" if out_file is None: file_name = self.data_cfg['file_paths']['ipv'].format( year=self.year ) out_file = os.path.join(self.data_cfg['wpath'], file_name) if not os.access(out_file, os.W_OK): write_dir = pkg_resources.resource_filename( 'STJ_PV', self.data_cfg['wpath'] ) out_file = os.path.join(write_dir, file_name) self.props.log.info('Begin writing to %s', out_file) self.get_data().to_netcdf(out_file) self.props.log.info('Finished writing to %s', out_file)
[docs]class InputDataSTJPV(InputData): """ Contains the relevant input data and routines for an STJPV jet find. Parameters ---------- jet_find : :py:meth:`~STJ_PV.run_stj.JetFindRun` Object containing properties about the metric calculation to be performed. Used to locate correct files, and variables within those files. year : int, optional Year of data to load, not used when all years are in a single file """ load_vars = ['uwnd', 'vwnd', 'tair', 'epv', 'ipv'] def __init__(self, props, date_s=None, date_e=None): """Initialize InputData object, using JetFindRun class.""" super(InputDataSTJPV, self).__init__(props, date_s, date_e) # Each STJPV input data _must_ have u-wind and isentropic pv # but _might_ also need the v-wind and air temperature to # calculate isentropic pv self.out_data = {'uwnd': None, 'ipv': None} self.th_lev = None def _find_pv_update(self): """Determine if PV needs to be computed/re-computed.""" pv_file_name = self.data_cfg['file_paths']['ipv'].format( year=self.year ) pv_file = os.path.join(self.data_cfg['wpath'], pv_file_name) return self.props.config['update_pv'] or not os.path.exists(pv_file) def _calc_ipv(self): # Shorthand for configuration dictionary cfg = self.data_cfg dimvars = {cvar: cfg[cvar] for cvar in ['time', 'lev', 'lat', 'lon']} if not self.in_data: self._load_data() self.props.log.info('Starting IPV calculation') # calculate IPV if cfg['ztype'] == 'pres': if 'epv' not in self.in_data: self.props.log.info('USING U, V, T TO COMPUTE IPV') ipv, _, uwnd = utils.xripv( self.in_data['uwnd'], self.in_data['vwnd'], self.in_data['tair'], dimvars=dimvars, th_levels=self.props.th_levels, ) else: self.props.log.info('USING ISOBARIC PV TO COMPUTE IPV') thta = utils.xrtheta(self.in_data['tair'], pvar=cfg['lev']) ipv = utils.xrvinterp( self.in_data['epv'], thta, self.props.th_levels, levname=cfg['lev'], newlevname=cfg['lev'], ) uwnd = utils.xrvinterp( self.in_data['uwnd'], thta, self.props.th_levels, levname=cfg['lev'], newlevname=cfg['lev'], ) self.out_data['ipv'] = ipv self.out_data['uwnd'] = uwnd self.th_lev = self.props.th_levels elif cfg['ztype'] == 'theta': ipv = utils.xripv_theta( self.in_data['uwnd'], self.in_data['vwnd'], self.in_data['pres'], dimvars, ) self.out_data['ipv'] = ipv self.out_data['uwnd'] = self.in_data['uwnd'] ipv_attrs = { 'units': '10^-6 PVU', 'standard_name': 'isentropic_potential_vorticity', 'descr': 'Potential vorticity on isentropic levels', } uwnd_attrs = { 'units': 'm s-1', 'standard_name': 'zonal_wind_component', 'descr': 'Zonal wind on isentropic levels', } self.out_data['ipv'] = self.out_data['ipv'].assign_attrs(ipv_attrs) self.out_data['uwnd'] = self.out_data['uwnd'].assign_attrs(uwnd_attrs) self.props.log.info('Finished calculating IPV') def _load_ipv(self): """Open IPV and Isentropic U-wind file(s), load into self.out_data.""" file_name = self.data_cfg['file_paths']['ipv'].format(year=self.year) in_file = os.path.join(self.data_cfg['wpath'], file_name) self.props.log.info("LOAD IPV FROM FILE: {}".format(in_file)) self._load_one_file('ipv') try: # Check for uwind in the IPV file first self._load_one_file('uwnd', file_var='ipv') except KeyError: # But fall back on the uwind file self._load_one_file('uwnd') self.out_data = self.in_data self.th_lev = self.in_data['ipv'][self.data_cfg['lev']] def _write_ipv(self): """Write generated IPV data to file.""" pv_file_name = self.data_cfg['file_paths']['ipv'].format( year=self.year ) pv_file = os.path.join(self.data_cfg['wpath'], pv_file_name) self.props.log.info('WRITING PV FILE %s', pv_file) encoding = {'zlib': True, 'complevel': 9} dsout = xr.Dataset(self.out_data) dsout[self.data_cfg['lev']] = dsout[self.data_cfg['lev']].assign_attrs( {'units': 'K', 'standard_name': 'potential_temperature'} ) dsout.encoding = dict((var, encoding) for var in dsout.data_vars) dsout.to_netcdf(pv_file, encoding=dsout.encoding) self.props.log.info('DONE WRITING PV FILE')
[docs] def get_data(self): """Load and compute required data, return xarray.Dataset.""" if 'force_write' in self.props.config: force_write = self.props.config['force_write'] else: force_write = False if self._find_pv_update(): self._calc_ipv() if self.sel[self.data_cfg['time']] == slice(None) or force_write: self._write_ipv() else: self._load_ipv() if self.th_lev[0] > self.th_lev[-1]: for data_var in ['uwnd', 'ipv']: self.out_data[data_var] = self.out_data[data_var][:, ::-1] self.th_lev = self.th_lev[::-1] return xr.Dataset( self.out_data, attrs={'cfg': self.data_cfg, 'year': self.year} )
[docs]class InputDataUWind(InputData): """ Contains the relevant input data and routines for an STJPV jet find. Parameters ---------- jet_find : :py:meth:`~STJ_PV.run_stj.JetFindRun` Object containing properties about the metric calculation to be performed. Used to locate correct files, and variables within those files. year : int, optional Year of data to load, not used when all years are in a single file """ def __init__(self, props, date_s=None, date_e=None, vwnd=False): """Initialize InputData object, using JetFindRun class.""" self.load_vars = ['uwnd'] if vwnd: self.load_vars.append("vwnd") super(InputDataUWind, self).__init__(props, date_s, date_e) # Each UWind input data _must_ have u-wind # but _might_ also need the pressure calculate isobaric uwind self.out_data = {} for var_name in self.load_vars: self.out_data[var_name] = None def _calc_interp(self, var_name): lev = self.props.p_levels data_interp = utils.xrvinterp( self.in_data[var_name], self.in_data['pres'], lev, levname=self.data_cfg['lev'], newlevname='pres', ) self.out_data[var_name] = data_interp
[docs] def get_data(self): """Load and compute (if needed) U-Wind on selected pressure level.""" if self.data_cfg == 'theta': self.load_vars.append('pres') self._load_data() for var_name in self.out_data: self._calc_interp(var_name) else: self._load_data() self.out_data = self.in_data return xr.Dataset( self.out_data, attrs={'cfg': self.data_cfg, 'year': self.year} )