Fitting a Model¶
TODO:
basic example of fitting w/ numpy arrays
fitting changing the batch_size, epochs, lr, and shuffle
fitting w/ a custom optimizer and/or optimizer kwargs
fitting w/ or w/o flipout
passing a pandas dataframe
passing a DataGenerator to fit
Using multiple MC samples per batch¶
By default, ProbFlow uses only one Monte Carlo sample from the variational
posteriors per batch. However, you can use more by passing the n_mc
keyword argument to Model.fit(). For example, to use 10 MC samples
during training:
model = pf.LinearRegression(x.shape[1])
model.fit(x, y, n_mc=10)
Using more MC samples makes the fit take longer, but the parameter optimization will be much more stable because the gradient estimates will have lower variance.
Note that Dense modules, which use the flipout estimator by default,
will not use flipout when n_mc > 1.
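To illustrate why averaging over more Monte Carlo samples stabilizes training, here is a minimal toy sketch using only the Python standard library. It is not ProbFlow's implementation: the objective (the expectation of w squared under a normal posterior) and the helper names grad_estimate and estimator_variance are assumptions chosen purely for illustration.

```python
import random
import statistics


def grad_estimate(mu, sigma, n_mc, rng):
    """Monte Carlo estimate of d/dmu E[w^2] with w ~ Normal(mu, sigma).

    Uses the reparameterization w = mu + sigma * eps, so each sample's
    gradient with respect to mu is 2 * w. The estimate is the mean of
    the per-sample gradients over n_mc samples.
    """
    samples = (mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n_mc))
    return sum(2.0 * w for w in samples) / n_mc


def estimator_variance(n_mc, n_repeats=5000, seed=0):
    """Empirical variance of the gradient estimate over repeated draws."""
    rng = random.Random(seed)
    estimates = [grad_estimate(1.0, 1.0, n_mc, rng) for _ in range(n_repeats)]
    return statistics.variance(estimates)


# Toy stand-ins for n_mc=1 (the default) and n_mc=10
var_1 = estimator_variance(n_mc=1)
var_10 = estimator_variance(n_mc=10)
print(f"gradient variance with n_mc=1:  {var_1:.3f}")
print(f"gradient variance with n_mc=10: {var_10:.3f}")
```

Because the samples are independent, the variance of the averaged gradient shrinks roughly in proportion to 1/n_mc, at the cost of n_mc times as many forward passes per batch.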
Backend graph optimization during fitting¶
By default, ProbFlow uses tf.function (for TensorFlow) or tracing (for PyTorch) to optimize the gradient computations during training. This generally makes training faster.
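import numpy as np
import probflow as pf
# Generate a toy linear-regression dataset
N = 1024
D = 7
randn = lambda *a: np.random.randn(*a).astype('float32')
x = randn(N, D)
w = randn(D, 1)
y = x@w + 0.1*randn(N, 1)
# Fit with the default graph-optimized training
model = pf.LinearRegression(D)
model.fit(x, y)
# takes around 5s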
To disable autograph/tracing and use only eager execution during model
fitting, pass the eager=True kwarg to fit. This takes longer but
can be more flexible in certain situations that autograph/tracing can’t
handle.
model.fit(x, y, eager=True)
# takes around 28s
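Timings like the ones quoted above depend on hardware and backend, so it can be worth measuring them yourself. The following is a minimal sketch of a timing helper; timed and dummy_fit are hypothetical names introduced here, and dummy_fit is only a stand-in workload so the snippet runs without ProbFlow installed.

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def dummy_fit(n):
    """Stand-in workload (not a real model fit)."""
    return sum(i * i for i in range(n))


result, seconds = timed(dummy_fit, 100_000)
print(f"took {seconds:.3f}s")

# With a ProbFlow model, the same helper could wrap the fit calls:
#   _, t_graph = timed(model.fit, x, y)
#   _, t_eager = timed(model.fit, x, y, eager=True)
```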
Warning
When the inputs are DataFrames or Series, it is not possible to use tracing
or tf.function, so ProbFlow falls back on eager execution by default.
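If the slower eager fallback is a problem, one possible workaround is to convert the DataFrame to float32 NumPy arrays before fitting. The sketch below shows only the pandas side of that conversion; the column names and values are made up for illustration, and whether the resulting arrays suit a particular model is an assumption to check against your own data.

```python
import pandas as pd

# Hypothetical toy data; the column names are made up for illustration.
df = pd.DataFrame({
    "x1": [0.1, 0.4, 0.7, 1.0],
    "x2": [1.0, 0.5, 0.2, 0.9],
    "y": [0.3, 0.6, 0.8, 1.4],
})
feature_cols = ["x1", "x2"]

# Convert to float32 arrays. Selecting y with a list of columns keeps it
# 2-D with shape (N, 1), like the y array in the NumPy example.
x = df[feature_cols].to_numpy(dtype="float32")
y = df[["y"]].to_numpy(dtype="float32")

print(x.shape, x.dtype)
print(y.shape, y.dtype)

# These arrays could then be passed to fit in place of the DataFrame:
#   model.fit(x, y)
```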
It’s much easier to debug models in eager mode, since you can step through your
own code using pdb instead of trying to step through the TensorFlow or PyTorch
compilation functions. So, if you’re getting an error when fitting your model
and want to debug the problem, try using eager=True when calling fit.
However, eager mode is used for all other ProbFlow functionality (e.g.
Model.predict(), Model.predictive_sample(), Model.metric(),
Model.posterior_sample(), etc.). If you want an optimized version of one of
ProbFlow’s inference-time methods, for TensorFlow you can wrap it in a
tf.function:
import tensorflow as tf
# model.fit(...)
@tf.function
def fast_predict(X):
    return model.predict(X)
fast_predict(x_test)
Or for PyTorch, use torch.jit.trace:
import torch
# model.fit(...)
def predict_fn(X):
    return model.predict(X)
# example_x: a representative input tensor used to record the trace
fast_predict = torch.jit.trace(predict_fn, (example_x,))
fast_predict(x_test)