from researchutils import files
import chainer.serializers
def save_model(path, model):
    """
    Write the parameters of ``model`` to an npz file located at ``path``.

    The function refuses to overwrite an existing file.

    Parameters
    ----------
    path : string
        destination path of the npz file to create
    model : chainer.Link
        model whose parameters are serialized

    Raises
    ------
    ValueError
        if a file already exists at ``path``
    """
    already_present = files.file_exists(path)
    if already_present:
        # Guard against silently clobbering a previously saved model.
        message = 'File already exists in {}'.format(path)
        raise ValueError(message)
    chainer.serializers.save_npz(path, model)
def load_model(path, model):
    """
    Load model parameters from the npz file at the given path.

    Parameters
    ----------
    path : string
        path of the saved model
    model : chainer.Link
        model whose parameters are overwritten in place with the values
        stored in the file

    Returns
    -------
    model : chainer.Link
        the same ``model`` object, with parameters initialized from the
        loaded file. If the file does not exist, the given model is returned
        without any changes.
    """
    if not files.file_exists(path):
        return model
    # load_npz deserializes into ``model`` in place and returns None, so the
    # model must be returned explicitly to honor the documented contract.
    chainer.serializers.load_npz(path, model)
    return model
def load_snapshot(path, trainer):
    """
    Load a trainer snapshot from the npz file at the given path.

    Parameters
    ----------
    path : string
        path of the saved snapshot
    trainer : chainer.Trainer
        trainer whose associated objects are restored in place from the file

    Returns
    -------
    trainer : chainer.Trainer
        the same ``trainer`` object, with its associated objects initialized
        from the loaded file. If the file does not exist, the given trainer is
        returned without any changes.
    """
    if not files.file_exists(path):
        return trainer
    # load_npz deserializes into ``trainer`` in place and returns None, so the
    # trainer must be returned explicitly to honor the documented contract.
    chainer.serializers.load_npz(path, trainer)
    return trainer