Can I hold 2 batches of dlnetwork gradients and update network parameters in 1 operation?

ternamichel_231 - 2021-04-13T10:22:51+00:00
Question: Can I hold 2 batches of dlnetwork gradients and update network parameters in 1 operation?

Because of limited GPU memory, my deep learning network can't train on a batch of 16 samples. Can I instead compute the gradients for a batch of 8 samples, and then update the network parameters using the gradients from 2 such batches? I compute the gradients of the network with [gradients,state,loss] = dlfeval(@modelGradient,dlNet,xTrain,yTrain); so after 2 batches I have gradients1, gradients2, state1, state2, loss1, and loss2. My first thought is that the total gradients should be the mean of gradients1 and gradients2. But how should I compute the state values? Is it also the mean of state1 and state2?

Expert Answer

Neeta Dsouza answered . 2025-11-20

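Yes, absolutely: just sum the gradients until you have covered the batch size you want, then update the model. If your loss is averaged over each sub-batch and the sub-batches are the same size, divide the sum by the number of sub-batches (that is, take the mean, as you suggest) so the result matches the gradient of one full batch. The principle is exactly the same as training a model on multiple GPUs (or CPUs).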
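
As a minimal sketch of that accumulation, assuming each gradients value is the table returned by dlgradient for the same network (so every table has the same Layer/Parameter rows), the helper below sums the tables with dlupdate and divides by the number of sub-batches. The function name averageGradients and its cell-array input are hypothetical, introduced only for illustration.

```matlab
function avgGradients = averageGradients(gradientList)
% averageGradients  Average a cell array of gradient tables (hypothetical helper).
% gradientList{k} is the gradients table from the k-th sub-batch. All tables
% are assumed to come from the same network and from equal-sized sub-batches.

numSubBatches = numel(gradientList);

% Start from the first sub-batch and add the remaining ones element-wise.
avgGradients = gradientList{1};
for k = 2:numSubBatches
    avgGradients = dlupdate(@(a,g) a + g, avgGradients, gradientList{k});
end

% Divide by the number of sub-batches so the result corresponds to the
% gradient of a loss averaged over the full effective batch.
avgGradients = dlupdate(@(g) g ./ numSubBatches, avgGradients);
end
```

In the training loop you would call averageGradients({gradients1,gradients2}) after the two dlfeval calls and pass the result to an update function such as adamupdate or sgdmupdate. If memory is tight, keeping a running sum instead of a list of tables avoids holding every sub-batch's gradients at once.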
 
How to update the State depends on what the State contains. This example shows you how to aggregate batch normalization state; the function of interest is aggregateState, reproduced here. Instead of using gplus to sum across parallel workers, you would just aggregate over your 'sub'-iterations.
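function state = aggregateState(state,factor)
% Aggregate batch normalization statistics across parallel workers.
% factor is this worker's share of the combined batch (for example 1/N for
% N equal-sized batches). gplus sums a value over all workers, so this
% function must run inside an spmd block (newer releases call it spmdPlus).

    numrows = size(state,1);
    
    % Stop at numrows-1 because each iteration also reads row j+1
    for j = 1:numrows-1
        % A batch normalization layer contributes two consecutive rows:
        % TrainedMean followed by TrainedVariance for the same layer
        isBatchNormalizationState = state.Parameter(j) =="TrainedMean"...
            && state.Parameter(j+1) =="TrainedVariance"...
            && state.Layer(j) == state.Layer(j+1);
        
        if isBatchNormalizationState
            meanVal = state.Value{j};
            varVal = state.Value{j+1};
            
            % Calculate combined mean
            combinedMean = gplus(factor*meanVal);
                   
            % Calculate combined variance terms to sum
            combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);        
            
            % Update state
            state.Value(j) = {combinedMean};
            state.Value(j+1) = {gplus(combinedVarTerm)};
           
        end
    end
end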
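
The same calculation can be adapted to sub-iterations on a single device by replacing each gplus call with an explicit sum over the stored states. The version below is only a sketch under stated assumptions: the function name aggregateStateOverSubBatches and its cell-array input are hypothetical, the sub-batches are assumed to be the same size (so each gets weight 1/N), and any state that is not batch normalization state is simply taken from the first sub-batch.

```matlab
function state = aggregateStateOverSubBatches(stateList)
% aggregateStateOverSubBatches  Combine batch normalization state from
% several sub-batches (hypothetical helper, not a toolbox function).
% stateList{k} is the state table from the k-th sub-batch; all tables are
% assumed to have identical Layer/Parameter rows.

numSubBatches = numel(stateList);
factor = 1/numSubBatches;      % equal weight per sub-batch (assumption)
state = stateList{1};          % non-batch-norm rows keep the first sub-batch's values
numRows = size(state,1);

for j = 1:numRows-1
    isBatchNormalizationState = state.Parameter(j) == "TrainedMean" ...
        && state.Parameter(j+1) == "TrainedVariance" ...
        && state.Layer(j) == state.Layer(j+1);

    if isBatchNormalizationState
        % Combined mean: weighted sum of the sub-batch means
        combinedMean = 0;
        for k = 1:numSubBatches
            combinedMean = combinedMean + factor .* stateList{k}.Value{j};
        end

        % Combined variance: within-batch variance plus the squared
        % offset of each sub-batch mean from the combined mean
        combinedVar = 0;
        for k = 1:numSubBatches
            meanVal = stateList{k}.Value{j};
            varVal = stateList{k}.Value{j+1};
            combinedVar = combinedVar + factor .* (varVal + (meanVal - combinedMean).^2);
        end

        % Update state
        state.Value(j) = {combinedMean};
        state.Value(j+1) = {combinedVar};
    end
end
end
```

With the variables from the question, this would be used as dlNet.State = aggregateStateOverSubBatches({state1,state2}); after the second sub-batch. Note that the variance is not the plain mean of the two variances: the extra squared term accounts for the sub-batch means differing from each other.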

