Source code for ronswanson.simulation_builder

import json
import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import h5py
import numpy as np

import yaml
from omegaconf import MISSING, OmegaConf
from tqdm.auto import tqdm

from ronswanson.utils.color import Colors
from ronswanson.utils.check_complete import check_complete_ids

from .grids import ParameterGrid
from .script_generator import (
    PythonGatherGenerator,
    PythonGenerator,
    SLURMGatherGenerator,
    SLURMGenerator,
)
from .utils.logging import setup_logger

log = setup_logger(__name__)


@dataclass
class SLURMTime:
    hrs: int = 0
    min: int = 10
    sec: int = 0


@dataclass
class JobConfig:
    time: SLURMTime
    n_cores_per_node: int


@dataclass
class GatherConfig(JobConfig):
    n_gather_per_core: int


@dataclass
class SimulationConfig(JobConfig):
    n_mp_jobs: int
    run_per_node: Optional[int] = None
    use_nodes: bool = False
    max_nodes: Optional[int] = None
    linear_execution: bool = False
    num_meta_parameters: Optional[int] = None
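
# A minimal programmatic configuration might look like this sketch
# (all values are illustrative assumptions):
#
#   config = SimulationConfig(
#       time=SLURMTime(hrs=1),
#       n_cores_per_node=8,
#       n_mp_jobs=8,
#       use_nodes=True,
#   )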

# structure for the YAML file


@dataclass
class JobConfigStructure:
    time: Optional[SLURMTime] = None
    n_cores_per_node: Optional[int] = None


@dataclass
class GatherConfigStructure(JobConfigStructure):
    n_gather_per_core: int = MISSING


@dataclass
class SimulationConfigStructure(JobConfigStructure):
    n_mp_jobs: int = MISSING
    run_per_node: Optional[int] = None
    use_nodes: bool = False
    max_nodes: Optional[int] = None
    linear_execution: bool = False


@dataclass
class YAMLStructure:
    import_line: str = MISSING
    parameter_grid: str = MISSING
    out_file: str = MISSING
    clean: bool = True
    simulation: SimulationConfigStructure = SimulationConfigStructure()
    gather: Optional[GatherConfigStructure] = None
    num_meta_parameters: Optional[int] = None
    finish_missing: bool = False
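
# An example setup file satisfying YAMLStructure might look like the
# sketch below. The import line, file names, and values are illustrative
# assumptions, not files shipped with the package:
#
#   import_line: "from my_package.my_simulation import MySimulation"
#   parameter_grid: "parameter_grid.yml"
#   out_file: "database.h5"
#   simulation:
#     n_mp_jobs: 8
#     n_cores_per_node: 8
#     use_nodes: false
#   gather:
#     n_cores_per_node: 8
#     n_gather_per_core: 10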


class SimulationBuilder:
    """
    The simulation builder class constructs the scripts needed for
    building the table model database
    """

    def __init__(
        self,
        parameter_grid: ParameterGrid,
        out_file: str,
        import_line: str,
        simulation_config: SimulationConfig,
        gather_config: Optional[GatherConfig] = None,
        num_meta_parameters: Optional[int] = None,
        clean: bool = True,
        finish_missing: bool = False,
    ):
        """
        Build the python (and optionally SLURM) scripts needed to run the
        simulations over the parameter grid and gather the results into
        the table model database.

        :param parameter_grid: the grid of parameters to iterate over
        :type parameter_grid: ParameterGrid
        :param out_file: path to the output HDF5 database
        :type out_file: str
        :param import_line: import statement for the user's simulation class
        :type import_line: str
        :param simulation_config: configuration for the simulation jobs
        :type simulation_config: SimulationConfig
        :param gather_config: configuration for the gather jobs
        :type gather_config: Optional[GatherConfig]
        :param num_meta_parameters: number of meta parameters stored alongside the outputs
        :type num_meta_parameters: Optional[int]
        :param clean: clean up intermediate files after gathering
        :type clean: bool
        :param finish_missing: only run the grid points missing from an existing database
        :type finish_missing: bool
        """

        self._has_complete_params: bool = False
        self._import_line: str = import_line
        self._simulation_config: SimulationConfig = simulation_config
        self._gather_config: Optional[GatherConfig] = gather_config
        self._out_file: str = out_file
        self._base_dir: Path = Path(out_file).parent.absolute()
        self._num_meta_parameters: Optional[int] = num_meta_parameters

        # write out the parameter file
        self._parameter_file: Path = self._base_dir / "parameters.yml"
        parameter_grid.write(str(self._parameter_file))

        self._n_outputs: int = len(parameter_grid.energy_grid)
        self._clean: bool = clean
        self._n_iterations: int = parameter_grid.n_points
        self._current_database_size: int = 0
        self._finish_missing: bool = finish_missing

        if not self._finish_missing:
            self._initialize_database()

        # if we are using nodes,
        # we need to see how many we need

        log.info(
            f"there are [bold bright_red]{self._n_iterations} iterations[/bold bright_red]"
        )

        if self._simulation_config.use_nodes:
            self._compute_chunks()
        else:
            self._n_nodes: Optional[int] = None
            self._n_gather_nodes: Optional[int] = None

        self._generate_python_script()

        if self._simulation_config.use_nodes:
            self._generate_slurm_script()

        output_dir = self._base_dir / "output"

        if not output_dir.exists():
            output_dir.mkdir()
            log.debug("created the output directory")

    @classmethod
    def from_yaml(cls, file_name: str) -> "SimulationBuilder":
        """
        Create a simulation setup from a yaml file.
        """

        # check the file structure

        structure = OmegaConf.structured(YAMLStructure)

        try:
            test_structure = OmegaConf.load(file_name)
            merged = OmegaConf.merge(structure, test_structure)
            OmegaConf.to_container(merged, throw_on_missing=True)

        except Exception as e:
            log.error(e)
            raise e

        with Path(file_name).open("r") as f:
            inputs = yaml.load(stream=f, Loader=yaml.SafeLoader)

        log.debug("reading setup inputs:")
        for k, v in inputs.items():
            log.debug(f"{k}: {v}")

        parameter_grid = ParameterGrid.from_yaml(inputs.pop("parameter_grid"))

        simulation_input = inputs.pop("simulation")

        if "time" in simulation_input:
            sim_time = SLURMTime(**simulation_input.pop("time"))
        else:
            sim_time = SLURMTime()

        if "n_cores_per_node" not in simulation_input:

            if "use_nodes" in simulation_input:

                if simulation_input["use_nodes"]:

                    log.warning(
                        "you are using nodes but did not specify n_cores_per_node"
                    )
                    log.warning(
                        "the number of cores will be set to the number of multi-process jobs"
                    )

            simulation_input["n_cores_per_node"] = simulation_input["n_mp_jobs"]

        simulation_config = SimulationConfig(time=sim_time, **simulation_input)

        gather_config = None

        if "gather" in inputs:

            gather_inputs = inputs.pop("gather")

            # time is optional in the structure, so fall back to the
            # default SLURMTime as in the simulation branch above
            if "time" in gather_inputs:
                gather_time = SLURMTime(**gather_inputs.pop("time"))
            else:
                gather_time = SLURMTime()

            gather_config = GatherConfig(time=gather_time, **gather_inputs)

        return cls(
            parameter_grid=parameter_grid,
            simulation_config=simulation_config,
            gather_config=gather_config,
            **inputs,
        )
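
    # Typical use goes through the yaml interface (a sketch; "setup.yml"
    # is an assumed file name):
    #
    #   builder = SimulationBuilder.from_yaml("setup.yml")
    #
    # Construction writes run_simulation.py (and, when use_nodes is set,
    # the SLURM scripts) into the directory that holds the database.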

    def _initialize_database(self) -> None:

        if not Path(self._out_file).exists():

            with h5py.File(self._out_file, "w") as f:

                f.attrs["has_been_touched"] = False

                pg = ParameterGrid.from_yaml(self._parameter_file)

                # store the parameter names

                p_name_group = f.create_group("parameter_names")

                for i, name in enumerate(pg.parameter_names):
                    p_name_group.attrs[f"par{i}"] = name

                # store the energy grids

                ene_grp = f.create_group("energy_grid")

                for i, grid in enumerate(pg.energy_grid):
                    ene_grp.create_dataset(
                        f"energy_grid_{i}",
                        data=grid.grid,
                        compression="gzip",
                    )

                # create an empty group for the parameters

                f.create_dataset(
                    "parameters",
                    shape=(pg.n_points,) + np.array(pg.parameter_names).shape,
                    maxshape=(None,) + np.array(pg.parameter_names).shape,
                    # compression="gzip",
                )

                val_grp: h5py.Group = f.create_group("values")

                # create an empty data set for the values

                for i in range(len(pg.energy_grid)):

                    val_grp.create_dataset(
                        f"output_{i}",
                        shape=(pg.n_points,) + pg.energy_grid[i].grid.shape,
                        maxshape=(None,) + pg.energy_grid[i].grid.shape,
                        # compression="gzip",
                    )

                f.create_dataset(
                    "run_time", shape=(pg.n_points,), maxshape=(None,)
                )

                if self._num_meta_parameters is not None:

                    meta_grp: h5py.Group = f.create_group("meta")

                    log.debug("detected meta parameters")

                    for i in range(self._num_meta_parameters):

                        meta_grp.create_dataset(
                            f"meta_{i}", shape=(pg.n_points,), maxshape=(None,)
                        )

        else:

            # we need to resize the dataset

            log.warning(
                f"There was already a database: [red]{self._out_file}[/red]"
            )

            with h5py.File(self._out_file, "r") as f:
                has_been_touched = f.attrs["has_been_touched"]

            if not has_been_touched:

                log.warning("the database has not been gathered")
                log.warning("erasing and starting over")

                Path(self._out_file).unlink()

                time.sleep(2)

                self._initialize_database()

                # the database was rebuilt from scratch,
                # so there is nothing to copy or resize
                return

            copy_file_name: str = f"{Path(self._out_file).parent}/{Path(self._out_file).stem}_copy{Path(self._out_file).suffix}"

            log.warning(f"a copy will be made to [blue]{copy_file_name}[/blue]")

            shutil.copy(self._out_file, copy_file_name)

            self._check_completed()

            with h5py.File(self._out_file, "a") as f:

                pg = ParameterGrid.from_yaml(self._parameter_file)

                dataset: h5py.Dataset = f["parameters"]

                self._current_database_size = dataset.shape[0]

                log.warning(
                    f"The existing database had {self._current_database_size} entries"
                )

                dataset.resize(
                    (self._current_database_size + pg.n_points,)
                    + dataset.shape[1:]
                )

                val_grp = f["values"]

                for i in range(len(pg.energy_grid)):

                    dataset: h5py.Dataset = val_grp[f"output_{i}"]

                    dataset.resize(
                        (self._current_database_size + pg.n_points,)
                        + dataset.shape[1:]
                    )

                dataset: h5py.Dataset = f["run_time"]

                dataset.resize((self._current_database_size + pg.n_points,))

                if self._num_meta_parameters is not None:

                    log.debug("detected meta parameters")

                    meta_grp = f["meta"]

                    for i in range(self._num_meta_parameters):

                        dataset: h5py.Dataset = meta_grp[f"meta_{i}"]

                        dataset.resize(
                            (self._current_database_size + pg.n_points,)
                        )

            self._n_outputs: int = len(pg.energy_grid)

    def _check_completed(self) -> None:

        if Path(self._out_file).exists():

            with h5py.File(self._out_file, "r") as f:
                params = f["parameters"][()]

            out_file = self._base_dir / "completed_parameters.json"

            with out_file.open("w") as f:
                json.dump(params.tolist(), f)

            self._has_complete_params = True

    def _compute_complete_ids(self):

        log.info("seeing how many are missing from the run")

        complete_ids = check_complete_ids(self._out_file)

        number_missing = self._n_iterations - len(complete_ids)

        log.info(f"there are {number_missing} missing runs")

        return np.array(complete_ids)

    def _compute_chunks(self) -> None:

        if self._finish_missing:

            complete_ids = self._compute_complete_ids()

            full = np.array(range(self._n_iterations))

            full[complete_ids] = -99

            incomplete_ids = full[full >= 0]

        else:

            complete_ids = []

        # we may only be cleaning up missing runs
        total_iterations: int = self._n_iterations - len(complete_ids)

        if self._simulation_config.run_per_node is None:

            log.debug("Each node will only execute the number of mp jobs")

            runs_per_node = 1

            generator = range(self._simulation_config.n_mp_jobs)

            n_nodes = np.ceil(
                total_iterations / self._simulation_config.n_mp_jobs
            )

        else:

            runs_per_node = self._simulation_config.run_per_node

            generator = range(runs_per_node)

            n_nodes = np.ceil(total_iterations / runs_per_node)

        if self._simulation_config.use_nodes:

            self._n_nodes = int(n_nodes)

            log.info(
                f"we will be using {self._n_nodes} nodes for the simulation"
            )

            # now generate the key files

            k = 0

            key_out = {}

            for i in tqdm(
                range(self._n_nodes),
                desc="computing node layout",
                colour=Colors.RED.value,
            ):

                output = []

                for j in generator:

                    if not self._finish_missing:

                        if k < self._n_iterations:
                            output.append(k)

                    else:

                        if k < total_iterations:
                            output.append(int(incomplete_ids[k]))

                    k += 1

                key_out[i] = output

            with open(self._base_dir / "key_file.json", "w") as f:
                json.dump(key_out, f)

        # now collect the gather information

        if self._simulation_config.use_nodes:

            self._n_gather_nodes = int(
                np.ceil(
                    self._n_iterations
                    / (
                        self._gather_config.n_cores_per_node
                        * self._gather_config.n_gather_per_core
                    )
                )
            )

            rank_list = {}

            n = 0

            log.info(f"the gather task will use: {self._n_gather_nodes} nodes")
            log.debug(
                f"total_ranks: {self._n_gather_nodes * self._gather_config.n_cores_per_node}"
            )
            log.debug(f"number iterations: {self._n_iterations}")

            for i in tqdm(
                range(
                    self._n_gather_nodes
                    * self._gather_config.n_cores_per_node
                ),
                desc="computing nodes for gather operation",
                colour=Colors.GREEN.value,
            ):

                core_list = []

                for j in range(self._gather_config.n_gather_per_core):

                    if n < self._n_iterations:
                        core_list.append(n)

                    n += 1

                rank_list[i] = core_list

            if not self._finish_missing:

                with open(self._base_dir / "gather_file.json", "w") as f:
                    json.dump(rank_list, f)

    def _generate_python_script(self) -> None:

        py_gen: PythonGenerator = PythonGenerator(
            "run_simulation.py",
            self._out_file,
            str(self._parameter_file),
            self._base_dir,
            self._import_line,
            self._simulation_config.n_mp_jobs,
            self._n_nodes,
            self._simulation_config.linear_execution,
            self._has_complete_params,
            self._current_database_size,
            clean=self._clean,
        )

        py_gen.write(str(self._base_dir))

        log.info(
            "[bold green blink]generated:[/bold green blink] run_simulation.py"
        )

    def _generate_slurm_script(self) -> None:

        multi_script: bool = False

        if self._simulation_config.max_nodes is not None:

            if self._n_nodes > self._simulation_config.max_nodes:

                log.debug("The number of requested nodes is too large.")

                multi_script = True

                n_files = int(
                    np.ceil(self._n_nodes / self._simulation_config.max_nodes)
                )

                start = []
                stop = []

                current_number = 0

                for i in range(n_files):

                    start.append(current_number)

                    next_number = int(
                        (i + 1) * self._simulation_config.max_nodes
                    )

                    if next_number <= self._n_nodes:

                        stop.append(next_number)

                        current_number = next_number

                    else:

                        stop.append(self._n_nodes)

                        break

        if multi_script:

            for i, (a, b) in enumerate(zip(start, stop)):

                file_name = f"run_simulation_{i}.sh"

                slurm_gen: SLURMGenerator = SLURMGenerator(
                    file_name,
                    self._simulation_config.n_mp_jobs,
                    self._simulation_config.n_cores_per_node,
                    b,
                    self._simulation_config.time.hrs,
                    self._simulation_config.time.min,
                    self._simulation_config.time.sec,
                    node_start=a,
                )

                slurm_gen.write(str(self._base_dir))

                log.info(
                    f"[bold green blink]generated:[/bold green blink] {file_name}"
                )

        else:

            slurm_gen: SLURMGenerator = SLURMGenerator(
                "run_simulation.sh",
                self._simulation_config.n_mp_jobs,
                self._simulation_config.n_cores_per_node,
                self._n_nodes,
                self._simulation_config.time.hrs,
                self._simulation_config.time.min,
                self._simulation_config.time.sec,
            )

            slurm_gen.write(str(self._base_dir))

            log.info(
                "[bold green blink]generated:[/bold green blink] run_simulation.sh"
            )

        if not self._finish_missing:

            slurm_gen: SLURMGatherGenerator = SLURMGatherGenerator(
                "gather_results.sh",
                self._gather_config.n_cores_per_node,
                self._n_gather_nodes,
                self._gather_config.time.hrs,
                self._gather_config.time.min,
                self._gather_config.time.sec,
            )

            slurm_gen.write(str(self._base_dir))

            log.info(
                "[bold green blink]generated:[/bold green blink] gather_results.sh"
            )

            python_gather_gen: PythonGatherGenerator = PythonGatherGenerator(
                "gather_results.py",
                database_file_name=self._out_file,
                current_size=self._current_database_size,
                n_outputs=self._n_outputs,
                clean=self._clean,
                num_meta_parameters=self._num_meta_parameters,
            )

            python_gather_gen.write(str(self._base_dir))

            log.info(
                "[bold green blink]generated:[/bold green blink] gather_results.py"
            )