Source code for openalea.cnwheat.tools

# -*- coding: latin-1 -*-

import os
import sys
from itertools import cycle
import warnings
import logging
import logging.config
import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

"""
    cnwheat.tools
    ~~~~~~~~~~~~~

    This module provides tools to help for the validation of the outputs: 
    
        * plot of multiple variables on the same graph, 
        * set up of loggers,
        * quantitative comparison test,
        * and progress-bar to follow the evolution of long simulations.  

"""

OUTPUTS_INDEXES = ['t', 'plant', 'axis', 'metamer', 'organ', 'element']  #: All the possible indexes of CN-Wheat outputs


[docs] class DataWarning(UserWarning): """Raised when there is no data to plot for a variable.""" def __init__(self, variable, keys): self.message = 'No data to plot for variable {} at {}.'.format(variable, keys) def __str__(self): return repr(self.message)
# show all DataWarning (not only the first one which occurred) warnings.simplefilter('always', DataWarning)
[docs] def plot_cnwheat_ouputs(outputs, x_name, y_name, x_label='', y_label='', x_lim=None, title=None, filters={}, plot_filepath=None, colors=[], linestyles=[], explicit_label=True, kwargs={}): """Plot `outputs`, with x=`x_name` and y=`y_name`. The general algorithm is: * find the scale of `outputs` and keep only the needed columns, * apply `filters` to `outputs` and make groups according to the scale, * plot each group as a new line, * save or display the plot. :param pandas.DataFrame outputs: The outputs of CN-Wheat. :param str x_name: x axis of the plot. :param str y_name: y axis of the plot. :param str x_label: The x label of the plot. Default is ''. :param str or unicode y_label: The y label of the plot. Default is ''. :param float x_lim: the x-axis limit. :param str title: the title of the plot. If None (default), create a title which is the concatenation of `y_name` and each scales which cardinality is one. :param dict filters: A dictionary whose keys are the columns of `outputs` for which we want to apply a specific filter. These columns can be one or more element of :const:`OUTPUTS_INDEXES`. The value associated to each key is a criteria that the rows of `outputs` must satisfy to be plotted. The values can be either one value or a list of values. If no value is given for any column, then all rows are plotted (default). :param list colors: The colors for lines. If empty, let matplotlib default line colors. :param list linestyles: The styles for lines. If empty, let matplotlib default line styles. :param str plot_filepath: The file path to save the plot. If `None`, do not save the plot but display it. :param bool explicit_label: True: makes the line label from concatenation of each scale id (default). - False: makes the line label from concatenation of scales containing several distinct elements. :param dict kwargs: key arguments to be passed to matplolib :Examples: >>> import pandas as pd >>> cnwheat_output_df = pd.read_csv('cnwheat_output.csv') # in this example, 'cnwheat_output.csv' must contain at least the columns 't' and 'Conc_Sucrose'. >>> plot(cnwheat_output_df, x_name = 't', y_name = 'Conc_Sucrose', x_label='Time (Hour)', y_label=u'[Sucrose] (µmol g$^{-1}$ mstruct)', title='{} = f({})'.format('Conc_Sucrose', 't'), filters={'plant': 1, 'axis': 'MS', 'organ': 'Lamina', 'element': 1}) """ # finds the scale of `outputs` group_keys = [key for key in OUTPUTS_INDEXES if key in outputs and key != x_name and key != y_name] # make a group_keys with first letter of each key in upper case group_keys_upper = [group_key[0].upper() + group_key[1:] for group_key in group_keys] # create a mapping to associate each key to its index in group_keys group_keys_mapping = dict([(key, index) for (index, key) in enumerate(group_keys)]) # keep only the needed columns (to make the grouping faster) outputs = outputs[group_keys + [x_name, y_name]] # apply filters to outputs for key, value in filters.items(): if key in outputs: # convert to list if needed try: _ = iter(value) except TypeError: values = [value] else: values = value # handle strings too if isinstance(values, str): values = [values] # select data from outputs outputs = outputs[outputs[key].isin(values)] # do not plot if there is nothing to plot if outputs[y_name].isnull().all(): return # compute the cardinality of each group keys and create the title if needed subtitle_groups = [] labels_groups = [] for i in range(len(group_keys)): group_key = group_keys[i] group_cardinality = outputs[group_key].nunique() if group_cardinality == 1: group_value = outputs[group_key][outputs.first_valid_index()] subtitle_groups.append('{}: {}'.format(group_keys_upper[i], group_value)) else: labels_groups.append(group_key) if title is None: # we need to create the title title = y_name + '\n' + ' - '.join(subtitle_groups) # makes groups according to the scale outputs_grouped = outputs.groupby(group_keys) # plots each group as a new line fig, ax = plt.subplots() matplot_colors_cycler = cycle(colors) matplot_linestyles_cycler = cycle(linestyles) for outputs_group_name, outputs_group in outputs_grouped: line_label_list = [] if explicit_label: # concatenate the keys of the group name line_label_list.extend(['{}: {}'.format(group_keys_upper[group_keys_mapping[output_group_name]], outputs_group_name) for output_group_name in outputs_group_name]) else: # construct a label with only the essential keys of the group name ; the essential keys are those for which cardinality is non zero for label_group in labels_groups: label_group_index = group_keys_mapping[label_group] line_label_list.append('{}: {}'.format(group_keys_upper[label_group_index], outputs_group_name[label_group_index])) kwargs['label'] = ' - '.join(line_label_list) # apply user colors try: color = next(matplot_colors_cycler) except StopIteration: pass else: kwargs['color'] = color # apply user lines style try: linestyle = next(matplot_linestyles_cycler) except StopIteration: pass else: kwargs['linestyle'] = linestyle # plot the line ax.plot(outputs_group[x_name], outputs_group[y_name], **kwargs) ax.set_ylim(bottom=0.) if x_lim is not None: ax.set_xlim(left=0, right=x_lim) else: ax.set_xlim(left=0) ax.set_xlabel(x_label) ax.set_ylabel(y_label) if kwargs['label']: ax.legend(prop={'size': 6}, framealpha=0.5, loc='center left', bbox_to_anchor=(1, 0.815), borderaxespad=0.) ax.set_title(title) plt.tight_layout() if plot_filepath is None: # display the plot plt.show() else: # save the plot plt.savefig(plot_filepath, dpi=200, format='PNG', bbox_inches='tight') plt.close()
[docs] def setup_logging(config_filepath='logging.json', level=logging.INFO, log_model=False, log_compartments=False, log_derivatives=False, remove_old_logs=False): """Setup logging configuration. :param str config_filepath: The file path of the logging configuration. :param int level: The global level of the logging. Use either `logging.DEBUG`, `logging.INFO`, `logging.WARNING`, `logging.ERROR` or `logging.CRITICAL`. :param bool log_model: if `True`, log the messages from :mod:`cnwheat.model`. `False` otherwise. :param bool log_compartments: if `True`, log the values of the compartments. `False` otherwise. :param bool log_derivatives: if `True`, log the values of the derivatives. `False` otherwise. :param bool remove_old_logs: if `True`, remove all files in the logs directory documented in `config_filepath`. """ if os.path.exists(config_filepath): with open(config_filepath, 'r') as f: config = json.load(f) if remove_old_logs: logs_dir = os.path.dirname(os.path.abspath(config['handlers']['file_info']['filename'])) for logs_file in os.listdir(logs_dir): os.remove(os.path.join(logs_dir, logs_file)) logging.config.dictConfig(config) else: logging.basicConfig() root_logger = logging.getLogger() root_logger.setLevel(level) cnwheat_model_logger = logging.getLogger('cnwheat.model') cnwheat_model_logger.disabled = not log_model # set to False to log messages from openalea.cnwheat.model logging.getLogger('cnwheat.compartments').disabled = not log_compartments # set to False to log the compartments logging.getLogger('cnwheat.derivatives').disabled = not log_derivatives # set to False to log the derivatives
[docs] def compare_actual_to_desired(data_dirpath, actual_data_df, desired_data_filename, actual_data_filename=None, precision=4, overwrite_desired_data=False): """Compare difference = actual_data_df - desired_data_df to tolerance = 10**-precision * (1 + abs(desired_data_df)) where desired_data_df = pd.read_csv(os.path.join(data_dirpath, desired_data_filename)) If difference > tolerance, then raise an AssertionError. :param str data_dirpath: The path of the directory where to find the data to compare. :param pandas.DataFrame actual_data_df: The computed data. :param str desired_data_filename: The file name of the expected data. :param str actual_data_filename: If not None, save the computed data to `actual_data_filename`, in directory `data_dirpath`. Default is None. :param int precision: The precision to use for the comparison. Default is `4`. :param bool overwrite_desired_data: If True the comparison between actual and desired data is not run. Instead, the desired data will be overwritten using actual data. To be used with caution. """ relative_tolerance = 10**-precision absolute_tolerance = relative_tolerance # read desired data desired_data_filepath = os.path.join(data_dirpath, desired_data_filename) desired_data_df = pd.read_csv(desired_data_filepath) if actual_data_filename is not None: # save actual outputs to CSV file actual_data_filepath = os.path.join(data_dirpath, actual_data_filename) actual_data_df.to_csv(actual_data_filepath, na_rep='NA', index=False, float_format='%.{}f'.format(precision)) if overwrite_desired_data: warnings.warn('!!! Unit test is running with overwrite_desired_data !!!') desired_data_filepath = os.path.join(data_dirpath, desired_data_filename) actual_data_df.to_csv(desired_data_filepath, na_rep='NA', index=False) else: # keep only numerical data (np.testing can compare only numerical data) for column in ('axis', 'organ', 'element', 'is_growing'): if column in desired_data_df.columns: del desired_data_df[column] del actual_data_df[column] # convert the actual outputs to floats actual_data_df = actual_data_df.astype(np.float64) # compare actual data to desired data np.testing.assert_allclose(actual_data_df.values, desired_data_df.values, relative_tolerance, absolute_tolerance)
[docs] class ProgressBarError(Exception): pass
[docs] class ProgressBar(object): """ Display a console progress bar. """ def __init__(self, bar_length=20, title='', block_character='#', uncomplete_character='-'): if bar_length <= 0: raise ProgressBarError('bar_length <= 0') self.bar_length = bar_length #: the number of blocks in the progress bar. MUST BE GREATER THAN ZERO ! self.t_max = 1 #: the maximum t that the progress bar can display. MUST BE GREATER THAN ZERO ! self.block_interval = 1 #: the time interval of each block. MUST BE GREATER THAN ZERO ! self.last_upper_t = 0 #: the last upper t displayed by the progress bar self.progress_mapping = {} #: a mapping to optimize the refresh rate self.title = title #: the title to write on the left side of the progress bar self.block_character = block_character #: the character to represent a block self.uncomplete_character = uncomplete_character #: the character to represent the uncompleted part of the progress bar
[docs] def set_t_max(self, t_max): """"Set :attr:`t_max` and update other attributes accordingly. """ if t_max <= 0: raise ProgressBarError('t_max <= 0') self.t_max = t_max self.block_interval = self.t_max / self.bar_length self.last_upper_t = 0 self.progress_mapping.clear()
[docs] def update(self, t): """Update the progress bar if needed. """ t = min(t, self.t_max) if t < self.last_upper_t: return else: self.last_upper_t = t t_inf = t // self.block_interval * self.block_interval if t_inf not in self.progress_mapping: progress = t / self.t_max block = int(round(self.bar_length * progress)) text = "\r{0}: [{1}] {2:>5d}% ".format(self.title, self.block_character * block + self.uncomplete_character * (self.bar_length - block), int(progress*100)) self.progress_mapping[t_inf] = text sys.stdout.write(self.progress_mapping[t_inf]) sys.stdout.flush()