mindspore.train
SummaryRecord.
Users can use SummaryRecord to dump the summary data; the summary is a series of operations to collect data for analysis and visualization.
- class
mindspore.train.summary.
SummaryRecord
(log_dir, queue_max_size=0, flush_time=120, file_prefix='events', file_suffix='_MS', network=None)[source] - Summary log record.
SummaryRecord is used to record the summary value. The API will create an event file in a given directory and add summaries and events to it.
- Parameters
log_dir (str) – The log_dir is a directory location to save the summary.
queue_max_size (int) – The capacity of the event queue (reserved). Default: 0.
flush_time (int) – Frequency to flush the summaries to disk, the unit is second. Default: 120.
file_prefix (str) – The prefix of file. Default: “events”.
file_suffix (str) – The suffix of file. Default: “_MS”.
network (Cell) – Obtain a pipeline through network for saving graph summary. Default: None.
Raises
TypeError – If queue_max_size or flush_time is not an int, or file_prefix or file_suffix is not a str.
RuntimeError – If the log_dir cannot be resolved to a canonicalized absolute pathname.
Examples
- >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
- >>> file_prefix="xxx_", file_suffix="_yyy")
close
()[source]- Flush all events and close summary records.
Examples
- >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
- >>> file_prefix="xxx_", file_suffix="_yyy")
- >>> summary_record.close()
flush
()[source]- Flush the event file to disk.
Call it to make sure that all pending events have been written to disk.
Examples
- >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
- >>> file_prefix="xxx_", file_suffix="_yyy")
- >>> summary_record.flush()
log_dir
Examples
- >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
- >>> file_prefix="xxx_", file_suffix="_yyy")
- >>> print(summary_record.log_dir)
- Returns
- String, the full path of the log file.
record
(step, train_network=None)[source]Record the summary.
Examples
- Copy>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
- >>> file_prefix="xxx_", file_suffix="_yyy")
- >>> summary_record.record(step=2)
- Returns
- bool, whether the record process is successful or not.
Callback related classes and functions.
- class
mindspore.train.callback.
Callback
[source] - Abstract base class used to build a callback function.
Callback functions will execute some operations in the current step or epoch.
Examples
- >>> class Print_info(Callback):
- >>>     def step_end(self, run_context):
- >>>         cb_params = run_context.original_args()
- >>>         print(cb_params.cur_epoch_num)
- >>>         print(cb_params.cur_step_num)
- >>>
- >>> print_cb = Print_info()
- >>> model.train(epoch, dataset, callbacks=print_cb)
begin
(run_context)[source]Called once before the network starts executing.
- Parameters
- run_context (RunContext) – Include some information of the model.
end
(run_context)[source]Called once after network training.
- Parameters
- run_context (RunContext) – Include some information of the model.
epoch_begin
(run_context)[source]Called before each epoch begins.
- Parameters
- run_context (RunContext) – Include some information of the model.
epoch_end
(run_context)[source]Called after each epoch finishes.
- Parameters
- run_context (RunContext) – Include some information of the model.
step_begin
(run_context)[source]Called before each step begins.
- Parameters
- run_context (RunContext) – Include some information of the model.
step_end
(run_context)[source]Called after each step finishes.
- Parameters
- run_context (RunContext) – Include some information of the model.
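The hook methods above are invoked in a fixed order around the training loop. The sketch below is a plain-Python illustration of that dispatch order only, not MindSpore's actual implementation; the `run_training` driver and `Tracer` class are hypothetical names introduced for the example.

```python
class Callback:
    # Hypothetical stand-in for mindspore.train.callback.Callback,
    # shown only to illustrate the hook invocation order.
    def begin(self, run_context): pass
    def epoch_begin(self, run_context): pass
    def step_begin(self, run_context): pass
    def step_end(self, run_context): pass
    def epoch_end(self, run_context): pass
    def end(self, run_context): pass

def run_training(callback, epochs, steps_per_epoch, run_context=None):
    # Drive the callback hooks in the order a training loop would.
    callback.begin(run_context)
    for _ in range(epochs):
        callback.epoch_begin(run_context)
        for _ in range(steps_per_epoch):
            callback.step_begin(run_context)
            # ... forward/backward pass would happen here ...
            callback.step_end(run_context)
        callback.epoch_end(run_context)
    callback.end(run_context)

class Tracer(Callback):
    # Records the order in which the hooks fire.
    def __init__(self): self.calls = []
    def begin(self, rc): self.calls.append("begin")
    def epoch_begin(self, rc): self.calls.append("epoch_begin")
    def step_begin(self, rc): self.calls.append("step_begin")
    def step_end(self, rc): self.calls.append("step_end")
    def epoch_end(self, rc): self.calls.append("epoch_end")
    def end(self, rc): self.calls.append("end")
```

For one epoch of two steps, the trace is begin, epoch_begin, step_begin, step_end, step_begin, step_end, epoch_end, end.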
- class
mindspore.train.callback.
LossMonitor
(per_print_times=1)[source] - Monitor the loss in training.
If the loss is NAN or INF, it will terminate training.
Note
If per_print_times is 0, the loss is not printed.
- Parameters
per_print_times (int) – Print the loss every per_print_times steps. Default: 1.
Raises
- ValueError – If per_print_times is not an int or is less than zero.
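The two behaviors described above, terminating on NAN/INF and printing every per_print_times steps, can be sketched in plain Python. This is a conceptual illustration, not MindSpore's implementation; `monitor_loss` is a hypothetical helper name.

```python
import math

def monitor_loss(losses, per_print_times=1):
    """Return the step numbers at which the loss would be printed.

    Raises ValueError when the loss becomes NAN or INF, mirroring
    LossMonitor's terminate-on-invalid-loss behavior (sketch only).
    """
    printed = []
    for step, loss in enumerate(losses, start=1):
        if math.isnan(loss) or math.isinf(loss):
            raise ValueError(f"Invalid loss {loss} at step {step}, terminating training.")
        # per_print_times == 0 means: never print.
        if per_print_times != 0 and step % per_print_times == 0:
            printed.append(step)
    return printed
```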
- class
mindspore.train.callback.
ModelCheckpoint
(prefix='CKP', directory=None, config=None)[source] - The checkpoint callback class.
It is called in combination with the training process and saves the model and network parameters after training.
- Parameters
prefix (str) – Prefix of checkpoint file names. Default: “CKP”.
directory (str) – Folder path into which checkpoint files will be saved. Default: None.
config (CheckpointConfig) – Checkpoint strategy config. Default: None.
Raises
ValueError – If the prefix is invalid.
TypeError – If the config is not CheckpointConfig type.
end
(run_context)[source]Save the last checkpoint after training finished.
- Parameters
- run_context (RunContext) – Context of the train running.
latest_ckpt_file_name
- Return the latest checkpoint path and file name.
step_end
(run_context)[source]Save the checkpoint at the end of the step.
- Parameters
- run_context (RunContext) – Context of the train running.
- class
mindspore.train.callback.
SummaryStep
(summary, flush_step=10)[source] - The summary callback class.
- Parameters
summary (Object) – Summary record object.
flush_step (int) – Number of interval steps to execute. Default: 10.
step_end
(run_context)[source]Save summary.
- Parameters
- run_context (RunContext) – Context of the train running.
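The flush_step parameter behaves as an interval on the step counter: the summary is saved every flush_step steps. A minimal sketch of that interval logic (illustrative only; `summary_steps` is a hypothetical helper, not a MindSpore API):

```python
def summary_steps(total_steps, flush_step=10):
    # Steps at which a SummaryStep-like callback would record the summary:
    # every flush_step-th step of training (conceptual sketch only).
    return [s for s in range(1, total_steps + 1) if s % flush_step == 0]
```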
- class
mindspore.train.callback.
CheckpointConfig
(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0)[source] - The config for model checkpoint.
- Parameters
save_checkpoint_steps (int) – Steps to save checkpoint. Default: 1.
save_checkpoint_seconds (int) – Seconds to save checkpoint. Default: 0. Can’t be used with save_checkpoint_steps at the same time.
keep_checkpoint_max (int) – Maximum number of checkpoint files that can be saved. Default: 5.
keep_checkpoint_per_n_minutes (int) – Keep one checkpoint every n minutes. Default: 0. Can’t be used with keep_checkpoint_max at the same time.
Raises
- ValueError – If the input_param is None or 0.
Examples
- >>> config = CheckpointConfig()
- >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
- >>> model.train(10, dataset, callbacks=ckpoint_cb)
get_checkpoint_policy
()[source] - Get the policy of checkpoint.
keep_checkpoint_max
- Get the value of _keep_checkpoint_max.
keep_checkpoint_per_n_minutes
- Get the value of _keep_checkpoint_per_n_minutes.
save_checkpoint_seconds
- Get the value of _save_checkpoint_seconds.
save_checkpoint_steps
- Get the value of _save_checkpoint_steps.
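The keep_checkpoint_max policy amounts to retaining only the newest N checkpoint files and discarding the oldest as new ones are written. A hedged plain-Python sketch of that retention policy (not MindSpore code; `keep_latest_checkpoints` is a hypothetical helper):

```python
from collections import deque

def keep_latest_checkpoints(ckpt_files, keep_checkpoint_max=5):
    # Sketch of the keep_checkpoint_max policy: retain only the newest
    # N checkpoint files, discarding the oldest (illustrative only).
    kept = deque(maxlen=keep_checkpoint_max)
    removed = []
    for f in ckpt_files:
        if len(kept) == keep_checkpoint_max:
            removed.append(kept[0])  # oldest file would be deleted on disk
        kept.append(f)
    return list(kept), removed
```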
- class
mindspore.train.callback.
RunContext
(original_args)[source] - Provides information about the model.
Provides information about the original request to the model function. Callback objects can stop the loop by calling request_stop() of run_context.
- Parameters
original_args (dict) – Holding the related information of model etc.
get_stop_requested
()[source]Returns whether a stop is requested or not.
- Returns
- bool, if true, model.train() stops iterations.
original_args
()[source]Get the _original_args object.
- Returns
- _InternalCallbackParam, an object holding the original arguments of the model.
request_stop
()[source]- Set the stop-requested flag during training.
Callbacks can use this function to request that iterations stop. model.train() checks whether this is called or not.
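The stop-request mechanics amount to a boolean flag that the training loop polls after each step. The sketch below is a minimal stand-in for that interaction, not MindSpore's RunContext; `RunContextSketch` and `train_loop` are hypothetical names for illustration.

```python
class RunContextSketch:
    # Minimal stand-in showing request_stop / get_stop_requested mechanics;
    # the real RunContext also carries the model's original arguments.
    def __init__(self, original_args):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        return self._original_args

    def request_stop(self):
        self._stop_requested = True

    def get_stop_requested(self):
        return self._stop_requested

def train_loop(run_context, max_steps):
    # The loop checks the flag after each step, as model.train() is
    # described to do; stops as soon as a callback requested it.
    steps_run = 0
    for _ in range(max_steps):
        steps_run += 1
        if steps_run == 3:  # some callback decides to stop early
            run_context.request_stop()
        if run_context.get_stop_requested():
            break
    return steps_run
```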
Model and parameters serialization.
mindspore.train.serialization.
save_checkpoint
(parameter_list, ckpoint_file_name)[source]Saves checkpoint info to a specified file.
- Parameters
parameter_list (list) – Parameters to be saved; each element is a dictionary describing one parameter (e.g. its name and data).
ckpoint_file_name (str) – Checkpoint file name.
Raises
- RuntimeError – Failed to save the Checkpoint file.
mindspore.train.serialization.
load_checkpoint
(ckpoint_file_name, net=None)[source]Loads checkpoint info from a specified file.
- Parameters
ckpoint_file_name (str) – Checkpoint file name.
net (Cell) – The network into which the parameters can be loaded. Default: None.
Returns
Dict, key is parameter name, value is a Parameter.
Raises
- ValueError – Checkpoint file is incorrect.
mindspore.train.serialization.
load_param_into_net
(net, parameter_dict)[source]Loads parameters into network.
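Conceptually, load_param_into_net copies values from the checkpoint dictionary (as returned by load_checkpoint) into the network parameters with matching names. The sketch below illustrates that matching with plain dicts; real MindSpore uses Parameter objects and a Cell network, and `load_param_dict` is a hypothetical helper name.

```python
def load_param_dict(net_params, parameter_dict):
    # Conceptual sketch of load_param_into_net: copy values from the
    # checkpoint dict into network parameters with matching names.
    # Returns the names of network parameters not found in the checkpoint.
    missing = []
    for name in net_params:
        if name in parameter_dict:
            net_params[name] = parameter_dict[name]
        else:
            missing.append(name)
    return missing
```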
mindspore.train.serialization.
export
(net, *inputs, file_name, file_format='GEIR')[source]Exports MindSpore predict model to file in specified format.
MindSpore currently supports ‘GEIR’, ‘ONNX’ and ‘LITE’ format for exported model.
- GEIR: Graph Engine Intermediate Representation. An intermediate representation format of Ascend model.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
- LITE: Huawei model format for mobile.
Auto mixed precision.
mindspore.train.amp.
build_train_network
(network, optimizer, loss_fn=None, level='O0', **kwargs)[source]Build the mixed precision training cell automatically.
- Parameters
level (str) – Supports [O0, O2]. Default: “O0”.
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm and loss_fn (if set) running in float32, using dynamic loss scale.
cast_model_type (mindspore.dtype) – Supports mstype.float16 or mstype.float32. If set to mstype.float16, use float16 mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool) – Keep Batchnorm running in float32. If set, overwrite the level setting.
loss_scale_manager (Union[None, LossScaleManager]) – If None, do not scale the loss; otherwise scale the loss by LossScaleManager. If set, overwrite the level setting.
Loss scale manager abstract class.
- class
mindspore.train.loss_scale_manager.
LossScaleManager
[source] Loss scale manager abstract class.
- class
mindspore.train.loss_scale_manager.
FixedLossScaleManager
(loss_scale=128.0, drop_overflow_update=True)[source] - Fixed loss-scale manager.
Examples
- >>> loss_scale_manager = FixedLossScaleManager()
- >>> model = Model(net, loss_scale_manager=loss_scale_manager)
get_drop_overflow_update
()[source]Get the flag of whether to drop the optimizer update when overflow happens.
get_loss_scale
()[source]Get loss scale value.
get_update_cell
()[source]Returns the cell for TrainOneStepWithLossScaleCell.
update_loss_scale
(overflow)[source]Update loss scale value.
- Parameters
- overflow (bool) – Whether it overflows.
- class
mindspore.train.loss_scale_manager.
DynamicLossScaleManager
(init_loss_scale=16777216, scale_factor=2, scale_window=2000)[source] - Dynamic loss-scale manager.
Examples
- >>> loss_scale_manager = DynamicLossScaleManager()
- >>> model = Model(net, loss_scale_manager=loss_scale_manager)
get_drop_overflow_update
()[source]Get the flag of whether to drop the optimizer update when overflow happens.
get_loss_scale
()[source]Get loss scale value.
get_update_cell
()[source]Returns the cell for TrainOneStepWithLossScaleCell.
update_loss_scale
(overflow)[source]Update loss scale value.
- Parameters
- overflow (bool) – Whether it overflows.
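A common dynamic loss-scaling update rule is to divide the scale by scale_factor when an overflow occurs and multiply it by scale_factor after scale_window consecutive overflow-free steps. The sketch below illustrates that rule using DynamicLossScaleManager's default constants; it is an assumption-laden illustration, and MindSpore's exact update rule (e.g. lower bounds) may differ.

```python
class DynamicLossScaleSketch:
    # Illustrative dynamic loss-scale update rule (NOT MindSpore's
    # implementation): shrink on overflow, grow after a run of good steps.
    def __init__(self, init_loss_scale=16777216, scale_factor=2, scale_window=2000):
        self.loss_scale = init_loss_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def update_loss_scale(self, overflow):
        if overflow:
            # Overflow: reduce the scale (kept >= 1 here by assumption).
            self.loss_scale = max(1, self.loss_scale / self.scale_factor)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A full window of clean steps: try a larger scale.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0

    def get_loss_scale(self):
        return self.loss_scale
```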