Parallelized batch fitting with Dask¶
This example demonstrates how to fit a set of spectra in parallel, to make use of multiple CPU cores available on a system.
To do this, it relies on the use of Dask
to create a distributed.LocalCluster
to submit work to.
First we create a sample dataset to fit using the create_variables
and create_spectrum
functions.
These we then submit to the cluster to fit using the client.map
call.
The client.gather(futures)
call then collects our results, and the call will block until the fitting is done.
Once we have the results, these are visualized using plotly
and panel
, so we can inspect them individually.
In [1]:
Copied!
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import Moose
import lmfit
import distributed
from functools import partial
import plotly.graph_objects as go
import panel as pn
import warnings
# Suppress UserWarning messages for the rest of the session
warnings.filterwarnings("ignore", category=UserWarning)

# Seeded generator so the simulated dataset is reproducible
rng = np.random.default_rng(0)

# Load the Plotly extension so Panel can render Plotly figures
pn.extension('plotly')

# Wavelength axis: 2000 evenly spaced points from 320 to 390
x = np.linspace(start=320, stop=390, num=2000)

# Query the 'N2CB' database over the range spanned by the wavelength axis
db = Moose.query_DB('N2CB', wl=(x.min(), x.max()))

# Wrap the Moose model function for lmfit, binding the queried database
model = lmfit.Model(Moose.model_for_fit, sim_db=db)

# Start from the Moose defaults, then cap both temperatures and fix 'A'
params = lmfit.create_params(**Moose.default_params)
params['T_rot'].max = 8000
params['T_vib'].max = 8000
params['A'].vary = False

# Start a cluster on the local machine and connect a client to it
client = distributed.Client()
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import Moose
import lmfit
import distributed
from functools import partial
import plotly.graph_objects as go
import panel as pn
import warnings
# Suppress UserWarning messages for the rest of the session
warnings.filterwarnings("ignore", category=UserWarning)

# Seeded generator so the simulated dataset is reproducible
rng = np.random.default_rng(0)

# Load the Plotly extension so Panel can render Plotly figures
pn.extension('plotly')

# Wavelength axis: 2000 evenly spaced points from 320 to 390
x = np.linspace(start=320, stop=390, num=2000)

# Query the 'N2CB' database over the range spanned by the wavelength axis
db = Moose.query_DB('N2CB', wl=(x.min(), x.max()))

# Wrap the Moose model function for lmfit, binding the queried database
model = lmfit.Model(Moose.model_for_fit, sim_db=db)

# Start from the Moose defaults, then cap both temperatures and fix 'A'
params = lmfit.create_params(**Moose.default_params)
params['T_rot'].max = 8000
params['T_vib'].max = 8000
params['A'].vary = False

# Start a cluster on the local machine and connect a client to it
client = distributed.Client()
In [7]:
Copied!
def create_variables(params: lmfit.Parameters) -> dict:
    """Draw a random value for every varying parameter in ``params``.

    Each value is sampled from a normal distribution centred on the
    parameter's current value, with a standard deviation of 5 % of that
    value's magnitude, and is then clipped to the parameter's
    ``[min, max]`` bounds.

    Parameters
    ----------
    params : lmfit.Parameters
        Parameter set to sample from. Only entries whose ``vary`` flag is
        ``True`` are included in the result.

    Returns
    -------
    dict
        Mapping of parameter name to the sampled value.
    """
    variables = {}
    for name, par in params.items():
        if par.vary is not True:
            continue
        # abs(): a negative parameter value would otherwise produce a
        # negative scale, which rng.normal rejects with a ValueError.
        sample = rng.normal(par.value, abs(par.value) * 0.05)
        # Keep the simulated value inside the bounds the fit is allowed to
        # explore, so the generated spectrum can actually be recovered.
        variables[name] = float(np.clip(sample, par.min, par.max))
    return variables
def create_spectrum(x: np.ndarray, model: lmfit.Model, noise: float = 0.01, **kwargs) -> np.ndarray:
    """Simulate a noisy spectrum and rescale it to the range [0, 1].

    The model is evaluated on ``x`` using the module-level ``params``;
    normally distributed noise is added to every point, and the result is
    min-max normalised. Noise is added *before* normalisation, so ``noise``
    is expressed in the units of the raw model output.

    Parameters
    ----------
    x : np.ndarray
        Axis on which the model is evaluated.
    model : lmfit.Model
        Model used to generate the noise-free spectrum.
    noise : float, optional
        Standard deviation of the added noise (default 0.01).
    **kwargs
        Forwarded to ``model.eval``; the caller uses this to pass the
        parameter values of the spectrum to simulate.

    Returns
    -------
    np.ndarray
        The noisy spectrum, shifted and scaled so its minimum is 0 and its
        maximum is 1.
    """
    # NOTE: `params` is the module-level parameter set, not an argument.
    clean = model.eval(x=x, params=params, **kwargs)
    noisy = rng.normal(clean, noise)
    lowest = noisy.min()
    highest = noisy.max()
    return (noisy - lowest) / (highest - lowest)
# Draw 20 random parameter sets ...
variables = [create_variables(params) for _ in range(20)]
# ... and simulate one spectrum per set, each with its own noise level
# drawn uniformly between 0.01 and 0.05.
spec = [
    create_spectrum(x, model, noise=rng.uniform(0.01, 0.05), **sampled)
    for sampled in variables
]
def create_variables(params: lmfit.Parameters) -> dict:
    """Draw a random value for every varying parameter in ``params``.

    Each value is sampled from a normal distribution centred on the
    parameter's current value, with a standard deviation of 5 % of that
    value's magnitude, and is then clipped to the parameter's
    ``[min, max]`` bounds.

    Parameters
    ----------
    params : lmfit.Parameters
        Parameter set to sample from. Only entries whose ``vary`` flag is
        ``True`` are included in the result.

    Returns
    -------
    dict
        Mapping of parameter name to the sampled value.
    """
    variables = {}
    for name, par in params.items():
        if par.vary is not True:
            continue
        # abs(): a negative parameter value would otherwise produce a
        # negative scale, which rng.normal rejects with a ValueError.
        sample = rng.normal(par.value, abs(par.value) * 0.05)
        # Keep the simulated value inside the bounds the fit is allowed to
        # explore, so the generated spectrum can actually be recovered.
        variables[name] = float(np.clip(sample, par.min, par.max))
    return variables
def create_spectrum(x: np.ndarray, model: lmfit.Model, noise: float = 0.01, **kwargs) -> np.ndarray:
    """Simulate a noisy spectrum and rescale it to the range [0, 1].

    The model is evaluated on ``x`` using the module-level ``params``;
    normally distributed noise is added to every point, and the result is
    min-max normalised. Noise is added *before* normalisation, so ``noise``
    is expressed in the units of the raw model output.

    Parameters
    ----------
    x : np.ndarray
        Axis on which the model is evaluated.
    model : lmfit.Model
        Model used to generate the noise-free spectrum.
    noise : float, optional
        Standard deviation of the added noise (default 0.01).
    **kwargs
        Forwarded to ``model.eval``; the caller uses this to pass the
        parameter values of the spectrum to simulate.

    Returns
    -------
    np.ndarray
        The noisy spectrum, shifted and scaled so its minimum is 0 and its
        maximum is 1.
    """
    # NOTE: `params` is the module-level parameter set, not an argument.
    clean = model.eval(x=x, params=params, **kwargs)
    noisy = rng.normal(clean, noise)
    lowest = noisy.min()
    highest = noisy.max()
    return (noisy - lowest) / (highest - lowest)
# Draw 20 random parameter sets ...
variables = [create_variables(params) for _ in range(20)]
# ... and simulate one spectrum per set, each with its own noise level
# drawn uniformly between 0.01 and 0.05.
spec = [
    create_spectrum(x, model, noise=rng.uniform(0.01, 0.05), **sampled)
    for sampled in variables
]
In [8]:
Copied!
# Bind the arguments shared by every fit, then submit one task per spectrum;
# each spectrum is passed as the remaining positional argument of model.fit.
fit_spectrum = partial(model.fit, x=x, params=params)
futures = client.map(fit_spectrum, spec)
# Bind the arguments shared by every fit, then submit one task per spectrum;
# each spectrum is passed as the remaining positional argument of model.fit.
fit_spectrum = partial(model.fit, x=x, params=params)
futures = client.map(fit_spectrum, spec)
In [9]:
Copied!
# Block until every fit has finished and collect the results in
# submission order; a failed task re-raises its exception here.
results = client.gather(futures, errors="raise")
# Block until every fit has finished and collect the results in
# submission order; a failed task re-raises its exception here.
results = client.gather(futures, errors="raise")
In [10]:
Copied!
# Spinner selecting which of the fitted spectra is shown
input_view_index = pn.widgets.Spinner(
    name='Index',
    start=0,
    end=len(results) - 1,
    value=0,
)

# One read-only numeric field per varying parameter, initialised from the
# first fit result
output_params = {
    k: pn.widgets.FloatInput(name=k, value=results[0].params[k].value, disabled=True)
    for k in params
    if params[k].vary is True
}

# Figure with three traces for the first spectrum: data, best fit, residual
fig = go.Figure(
    data=[
        go.Scatter(x=x, y=spec[0], name='Spectrum'),
        go.Scatter(x=x, y=results[0].eval(x=x), name='Fit'),
        go.Scatter(x=x, y=results[0].residual, name='Residual'),
    ]
)
fig.update_xaxes(title_text='Wavelength (nm)')
fig.update_yaxes(title_text='Norm. I. (a.u.)')

# Plot on top, index selector below it, parameter read-outs at the bottom
layout = pn.Column(
    pn.pane.Plotly(fig, sizing_mode='stretch_width'),
    input_view_index,
    pn.FlexBox(*output_params.values()),
)
def cb_update_view(*args):
    """Refresh the plot and parameter fields for the selected fit.

    Used as a watcher callback on the index spinner. The event arguments
    are ignored; the selected index is read from the widget itself.
    """
    idx = input_view_index.value
    selected = results[idx]

    # Apply all trace changes as a single figure update
    with fig.batch_update():
        new_y = (spec[idx], selected.eval(x=x), selected.residual)
        for trace, y_values in zip(fig.data, new_y):
            trace.x = x
            trace.y = y_values

    # Only varying parameters have a read-out widget
    for name, par in selected.params.items():
        if params[name].vary is True:
            output_params[name].value = par.value


# Run the callback whenever the spinner's 'value_throttled' parameter changes
input_view_index.param.watch(cb_update_view, 'value_throttled')
# Bare expression: the notebook renders the layout as the cell output
layout
# Spinner selecting which of the fitted spectra is shown
input_view_index = pn.widgets.Spinner(
    name='Index',
    start=0,
    end=len(results) - 1,
    value=0,
)

# One read-only numeric field per varying parameter, initialised from the
# first fit result
output_params = {
    k: pn.widgets.FloatInput(name=k, value=results[0].params[k].value, disabled=True)
    for k in params
    if params[k].vary is True
}

# Figure with three traces for the first spectrum: data, best fit, residual
fig = go.Figure(
    data=[
        go.Scatter(x=x, y=spec[0], name='Spectrum'),
        go.Scatter(x=x, y=results[0].eval(x=x), name='Fit'),
        go.Scatter(x=x, y=results[0].residual, name='Residual'),
    ]
)
fig.update_xaxes(title_text='Wavelength (nm)')
fig.update_yaxes(title_text='Norm. I. (a.u.)')

# Plot on top, index selector below it, parameter read-outs at the bottom
layout = pn.Column(
    pn.pane.Plotly(fig, sizing_mode='stretch_width'),
    input_view_index,
    pn.FlexBox(*output_params.values()),
)
def cb_update_view(*args):
    """Refresh the plot and parameter fields for the selected fit.

    Used as a watcher callback on the index spinner. The event arguments
    are ignored; the selected index is read from the widget itself.
    """
    idx = input_view_index.value
    selected = results[idx]

    # Apply all trace changes as a single figure update
    with fig.batch_update():
        new_y = (spec[idx], selected.eval(x=x), selected.residual)
        for trace, y_values in zip(fig.data, new_y):
            trace.x = x
            trace.y = y_values

    # Only varying parameters have a read-out widget
    for name, par in selected.params.items():
        if params[name].vary is True:
            output_params[name].value = par.value


# Run the callback whenever the spinner's 'value_throttled' parameter changes
input_view_index.param.watch(cb_update_view, 'value_throttled')
# Bare expression: the notebook renders the layout as the cell output
layout
Out[10]:
In [ ]:
Copied!