diff options
Diffstat (limited to 'hittekaart-py')
| -rw-r--r-- | hittekaart-py/src/lib.rs | 126 | 
1 files changed, 90 insertions, 36 deletions
diff --git a/hittekaart-py/src/lib.rs b/hittekaart-py/src/lib.rs index fe97aa7..2f99793 100644 --- a/hittekaart-py/src/lib.rs +++ b/hittekaart-py/src/lib.rs @@ -123,7 +123,7 @@ impl Storage {  }  impl Storage { -    fn open(&self) -> PyResult<Box<dyn hittekaart::storage::Storage>> { +    fn open(&self) -> PyResult<Box<dyn hittekaart::storage::Storage + Send>> {          match self.0 {              StorageType::Folder(ref path) => {                  let storage = hittekaart::storage::Folder::new(path.clone()); @@ -194,13 +194,72 @@ impl TilehuntRenderer {      }  } +/// Tile generation settings. +/// +/// This contains everything that is overarching to the renderers and output modules, such as zoom +/// levels and thread count. +#[pyclass] +struct Settings { +    #[pyo3(get, set)] +    /// Smallest zoom level that will be generated. +    min_zoom: u32, + +    #[pyo3(get, set)] +    /// Largest zoom level that will be generated. +    max_zoom: u32, + +    #[pyo3(get, set)] +    /// How many threads to use for generation. +    /// +    /// A count of 0 will automatically use as many threads as you have CPU cores. +    threads: u32, +} + +#[pymethods] +impl Settings { +    #[new] +    #[pyo3(signature = (min_zoom = 1, max_zoom = 19, threads = 0))] +    fn new(min_zoom: u32, max_zoom: u32, threads: u32) -> Settings { +        Settings { +            min_zoom, +            max_zoom, +            threads, +        } +    } + +    fn __repr__(&self) -> String { +        format!( +            "Settings(min_zoom={}, max_zoom={}, threads={})", +            self.min_zoom, self.max_zoom, self.threads +        ) +    } +} + +macro_rules! dispatch_generate { +    ($settings:expr, $tracks:expr, $renderer:expr, $storage:expr => <$type:ty>, $(<$types:ty>,)*) => { +        if let Ok(r) = $renderer.downcast::<$type>() { +            do_generate($settings, $tracks, &r.borrow().inner, $storage) +        } else { +            dispatch_generate!($settings, $tracks, $renderer, $storage => $(<$types>,)*) +        } +    }; + +    ($settings:expr, $tracks:expr, $renderer:expr, $storage:expr =>) => { +        Err(PyTypeError::new_err( +            "Expected a HeatmapRenderer, MarktileRenderer or TilehuntRenderer", +        )) +    }; +} +  /// Generate the heatmap.  /// +/// * settings are the rendering Settings  /// * items is an iterable of Track  /// * renderer should be one of the renderers (such as HeatmapRenderer)  /// * storage is the Storage output  #[pyfunction]  fn generate( +    settings: &Bound<'_, Settings>,      items: &Bound<'_, PyAny>,      renderer: &Bound<'_, PyAny>,      storage: &Bound<'_, Storage>, @@ -212,53 +271,48 @@ fn generate(          tracks.push(item.extract::<Track>()?.inner);      } -    if let Ok(r) = renderer.downcast::<HeatmapRenderer>() { -        do_generate(tracks, &r.borrow().inner, &mut *storage.borrow().open()?) -    } else if let Ok(r) = renderer.downcast::<MarktileRenderer>() { -        do_generate(tracks, &r.borrow().inner, &mut *storage.borrow().open()?) -    } else if let Ok(r) = renderer.downcast::<TilehuntRenderer>() { -        do_generate(tracks, &r.borrow().inner, &mut *storage.borrow().open()?) -    } else { -        Err(PyTypeError::new_err("Expected a HeatmapRenderer, MarktileRenderer or TilehuntRenderer")) +    let settings = &*settings.borrow(); + +    // We cannot easily do dynamic dispatch here, because Renderer::Prepared exists. Maybe this can +    // change in the future, but for now we have to stick to this: +    dispatch_generate! { +        settings, tracks, renderer, &mut *storage.borrow().open()? => +            <HeatmapRenderer>, +            <MarktileRenderer>, +            <TilehuntRenderer>,      }  }  fn do_generate<R: Renderer>( +    settings: &Settings,      tracks: Vec<Vec<Coordinates>>,      renderer: &R, -    storage: &mut dyn hittekaart::storage::Storage, +    storage: &mut (dyn hittekaart::storage::Storage + Send),  ) -> PyResult<()> {      storage.prepare().map_err(|e| err_to_py(&e))?; -    for zoom in 0..=19 { -        let counter = -            renderer::prepare(renderer, zoom, &tracks, || Ok(())).map_err(|e| err_to_py(&e))?; +    let pool = rayon::ThreadPoolBuilder::new() +        .num_threads(settings.threads.try_into().unwrap()) +        .build() +        .map_err(|e| err_to_py(&e))?; -        storage.prepare_zoom(zoom).map_err(|e| err_to_py(&e))?; +    pool.install(|| { +        for zoom in settings.min_zoom..=settings.max_zoom { +            let counter = +                renderer::prepare(renderer, zoom, &tracks, || Ok(())).map_err(|e| err_to_py(&e))?; -        renderer::colorize(renderer, counter, |rendered_tile| { -            storage.store(zoom, rendered_tile.x, rendered_tile.y, &rendered_tile.data)?; -            Ok(()) -        }) -        .map_err(|e| err_to_py(&e))?; -    } -    storage.finish().map_err(|e| err_to_py(&e))?; +            storage.prepare_zoom(zoom).map_err(|e| err_to_py(&e))?; -    Ok(()) -} +            renderer::colorize(renderer, counter, |rendered_tile| { +                storage.store(zoom, rendered_tile.x, rendered_tile.y, &rendered_tile.data)?; +                Ok(()) +            }) +            .map_err(|e| err_to_py(&e))?; +        } +        storage.finish().map_err(|e| err_to_py(&e))?; -/// Set the number of threads that hittekaart will use. -/// -/// Note that this is a global function, it will affect all subsequent calls. -/// -/// Note further that you may only call this function once, at startup. Calls after the thread pool -/// has been initialized (e.g. via a generate() or set_threads() call) will raise an exception. -#[pyfunction] -fn set_threads(threads: usize) -> PyResult<()> { -    rayon::ThreadPoolBuilder::new() -        .num_threads(threads) -        .build_global() -        .map_err(|e| err_to_py(&e)) +        Ok(()) +    })  }  /// Python bindings for the hittekaart heatmap tile generator. @@ -278,8 +332,8 @@ fn hittekaart_py(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {      m.add_class::<MarktileRenderer>()?;      m.add_class::<TilehuntRenderer>()?;      m.add_class::<Storage>()?; +    m.add_class::<Settings>()?;      m.add_function(wrap_pyfunction!(generate, m)?)?; -    m.add_function(wrap_pyfunction!(set_threads, m)?)?;      m.add("HitteError", py.get_type::<HitteError>())?;      Ok(())  }  | 
