Retraining a neural network from its previous state using trainNetwork

riyanjain - 2021-05-10T13:46:08+00:00
Question: Retraining a neural network from its previous state using trainNetwork

I am training a deep neural network using the following MATLAB function:

net = trainNetwork(XTrain,YTrain,layers,options);

Can I use trainNetwork to retrain the network (not from scratch), starting from the network state at the end of the previous training? Here is some of my code:

layers = [
    imageInputLayer([1 Nin 5],"Name","imageinput")
    convolution2dLayer([1 3],32,"Name","conv1","Padding","same")   % layer names must be unique
    %batchNormalizationLayer('Name','batchDown')
    tanhLayer("Name","tanh1")
    convolution2dLayer([1 3],32,"Name","conv2","Padding","same","DilationFactor",[1 3])
    %batchNormalizationLayer('Name','batchDown1')
    tanhLayer("Name","tanh2")
    convolution2dLayer([1 3],2,"Name","conv3","Padding","same","DilationFactor",[1 9])
    regressionLayer];

options = trainingOptions('adam', ...
    'InitialLearnRate',0.001, ...
    'MaxEpochs',50000, ...
    'ExecutionEnvironment','parallel', ...
    'Verbose',false, ...
    'Plots','training-progress');

Net1 = trainNetwork(XTrain1,YTrain1,layers,options);

Now I would like to train again on new data to get Net2, starting from the state reached by Net1. For example:

Net2 = trainNetwork(XTrain2,YTrain2,layers,options);

However, it is not clear to me how to start the training of Net2 from Net1.

Expert Answer

Kshitij Singh answered 2025-11-20

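If Net1 is a SeriesNetwork (which is what trainNetwork returns for a plain layer array like the one in the question), pass its Layers property to trainNetwork:

Net2 = trainNetwork(XTrain2, YTrain2, Net1.Layers, options);

If Net1 is a DAGNetwork, convert it to a layer graph first:

Net2 = trainNetwork(XTrain2, YTrain2, layerGraph(Net1), options);

In both cases the layers carry the weights and biases learned for Net1, so trainNetwork uses them as the initial values instead of initializing the network from scratch.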
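A minimal end-to-end sketch of the SeriesNetwork case, assuming the Deep Learning Toolbox is installed. Random arrays stand in for XTrain1/YTrain1 and XTrain2/YTrain2, and Nin = 32 and numObs are hypothetical sizes. MaxEpochs is cut to 5, MiniBatchSize is set explicitly, and 'ExecutionEnvironment','parallel' is dropped so the sketch runs quickly without a parallel pool.

```matlab
% Sketch: continue training a SeriesNetwork from its previous state.
% Synthetic data stands in for the real training sets (assumption).
rng(0);
Nin = 32;                                  % hypothetical input width
numObs = 64;                               % hypothetical number of observations
XTrain1 = randn(1, Nin, 5, numObs);        % height x width x channels x observations
YTrain1 = randn(1, Nin, 2, numObs);        % same size as the output of the last conv layer
XTrain2 = randn(1, Nin, 5, numObs);        % "new data" for the second run
YTrain2 = randn(1, Nin, 2, numObs);

layers = [
    imageInputLayer([1 Nin 5], "Name", "imageinput")
    convolution2dLayer([1 3], 32, "Name", "conv1", "Padding", "same")
    tanhLayer("Name", "tanh1")
    convolution2dLayer([1 3], 32, "Name", "conv2", "Padding", "same", "DilationFactor", [1 3])
    tanhLayer("Name", "tanh2")
    convolution2dLayer([1 3], 2, "Name", "conv3", "Padding", "same", "DilationFactor", [1 9])
    regressionLayer("Name", "output")];

options = trainingOptions("adam", ...
    "InitialLearnRate", 0.001, ...
    "MaxEpochs", 5, ...                    % reduced from 50000 for the sketch
    "MiniBatchSize", 16, ...
    "Verbose", false);

% First run: weights are initialized from scratch.
Net1 = trainNetwork(XTrain1, YTrain1, layers, options);

% Second run: Net1.Layers already holds the learned Weights and Bias,
% so training starts from them rather than from a new initialization.
Net2 = trainNetwork(XTrain2, YTrain2, Net1.Layers, options);
```

Only the learned parameters carry over: the second call starts a fresh Adam run with whatever options you pass, so a smaller InitialLearnRate may be worth trying when fine-tuning on the new data.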
 
If you also want to prevent the weights in certain layers from changing, you can use the freezeWeights helper function, which sets the learn-rate factors of those layers to zero. During training, trainNetwork does not update the parameters of these "frozen" layers. The helper ships with the MATLAB deep learning examples; you can open it with:

edit(fullfile(matlabroot,'examples','nnet','main','freezeWeights.m'))
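A minimal sketch of freezing one layer by hand instead of calling freezeWeights, assuming (as described above) that freezing amounts to setting a layer's learn-rate factors to zero. The data is random, and the two-convolution network is a hypothetical reduced version of the one in the question; only conv1 is frozen.

```matlab
% Sketch (assumption): freeze a layer by zeroing its learn-rate factors.
rng(0);
Nin = 32;  numObs = 64;                    % hypothetical sizes for synthetic data
XTrain1 = randn(1, Nin, 5, numObs);  YTrain1 = randn(1, Nin, 2, numObs);
XTrain2 = randn(1, Nin, 5, numObs);  YTrain2 = randn(1, Nin, 2, numObs);

layers = [
    imageInputLayer([1 Nin 5], "Name", "imageinput")
    convolution2dLayer([1 3], 32, "Name", "conv1", "Padding", "same")
    tanhLayer("Name", "tanh1")
    convolution2dLayer([1 3], 2, "Name", "conv2", "Padding", "same")
    regressionLayer("Name", "output")];
options = trainingOptions("adam", "MaxEpochs", 5, "MiniBatchSize", 16, "Verbose", false);

Net1 = trainNetwork(XTrain1, YTrain1, layers, options);

% Copy the trained layers and zero the learn-rate factors of conv1 (index 2).
layers2 = Net1.Layers;
layers2(2).WeightLearnRateFactor = 0;
layers2(2).BiasLearnRateFactor = 0;

% Continue training on the new data: conv1 keeps Net1's values,
% while conv2 is still updated.
Net2 = trainNetwork(XTrain2, YTrain2, layers2, options);
```

To freeze several layers, repeat the two assignments for each layer index; freezeWeights automates the same idea for a whole range of layers.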