Divide a 4D array into training set and validation set for CNN (regression)

Rebekah - 2021-04-07T12:22:43+00:00

Hi everybody, I am trying to design a CNN for regression following this MATLAB example. It uses a 4D array to store the images and a vector to store the value associated with each picture. I am using this code to create a 4D array called 'database' that contains my images and a vector 'labels' that contains the values:

k = 1;
%2cm
for i = 1:1000
    str = sprintf('images/2cm/%d.jpg', i);
    image_to_store = imread(str);
    database(:,:,1,k) = (image_to_store(:,:)); % images are in grey scale
    labels(k) = 2;
    k = k+1;
end
%20cm
for i = 1:1000
    str = sprintf('images/20cm/%d.jpg', i);
    image_to_store = imread(str);
    database(:,:,1,k) = (image_to_store(:,:));
    labels(k) = 20;
    k = k+1;
end
% ...

Now I have my 4D array and the vector, so I am trying to divide them into a training set and a validation set as suggested in the linked example. Can anyone please help me understand how I can do that?

Expert Answer

John Williams answered on 2025-11-20.

Hope this does what you wanted:

 

% Stand-in for your data set (random values with the same shape):
% the first 1000 entries are the 2cm images (label 2),
% the second 1000 entries are the 20cm images (label 20)
database = rand(28,28,1,2000);


% percentage of training points = 70%, validation = 30%, test = 0% 
p=0.7;

% One way to divide the 2000 database entries 
[trainInd,valInd,testInd] = dividerand(2000,p,1-p,0);
trainDatabaseBad = database(:,:,:,trainInd);
valDatabaseBad = database(:,:,:,valInd);

size(trainDatabaseBad) % output: 28 28 1 1400
size(valDatabaseBad) % output: 28 28 1 600

% A better way to divide, which ensures that
% the proportion of 2cm to 20cm samples is the same in
% the training set, the validation set, and the whole set
[trainInd1,valInd1,testInd1] = dividerand(1000,p,1-p,0);
[trainInd2,valInd2,testInd2] = dividerand(1000,p,1-p,0);
% The 20cm samples are entries 1001-2000, so shift the second set of indices by 1000
trainDatabase = cat(4, database(:,:,:,trainInd1), database(:,:,:,1000+trainInd2));
valDatabase = cat(4, database(:,:,:,valInd1), database(:,:,:,1000+valInd2));

size(trainDatabase) % output: 28 28 1 1400
size(valDatabase) % output: 28 28 1 600
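Whichever split you use, the labels vector has to be indexed with the same positions as the images; otherwise each image ends up paired with the wrong value. Below is a minimal sketch of that, assuming labels is the row vector built in the question (2 for the first 1000 images, 20 for the next 1000). It uses randperm from base MATLAB in place of dividerand and shuffles within each label value, so it also keeps the 2cm/20cm proportions equal. The names trainIdx, valIdx, XTrain, YTrain, XVal and YVal are illustrative, not from the original example.

```matlab
% Stand-in data shaped like the question's: 1000 images at 2cm, then 1000 at 20cm
database = rand(28, 28, 1, 2000);
labels = [2*ones(1, 1000), 20*ones(1, 1000)];

p = 0.7;                      % fraction of each label value used for training
trainIdx = [];
valIdx = [];
classes = unique(labels);     % works for any number of distance folders
for c = 1:numel(classes)
    idx = find(labels == classes(c));     % positions of this label in the full set
    idx = idx(randperm(numel(idx)));      % shuffle within this label only
    nTrain = round(p * numel(idx));
    trainIdx = [trainIdx, idx(1:nTrain)];         %#ok<AGROW>
    valIdx   = [valIdx,   idx(nTrain+1:end)];     %#ok<AGROW>
end

% Use the SAME indices for images and labels so they stay paired
XTrain = database(:, :, :, trainIdx);
YTrain = labels(trainIdx)';   % column vector, one response per image
XVal   = database(:, :, :, valIdx);
YVal   = labels(valIdx)';
```

YTrain and YVal come out as column vectors (one response per image), which matches the N-by-1 response layout used for image regression with trainNetwork.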
A word of caution: holding the whole data set in memory may not be practical if it is large. In that case you can use datastores, e.g. imageDatastore. They make the split above much simpler with the splitEachLabel function, and they also help with things like data augmentation.
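A minimal sketch of the datastore route, assuming the images sit in subfolders named after their distance (images/2cm, images/20cm, and so on, as in the question) so that imageDatastore can take each label from its folder name. To keep the sketch self-contained it first writes a few random stand-in images to a temporary folder; with real data you would point rootFolder at your own images folder instead. splitEachLabel works on categorical labels, so the folder names are converted back to numeric responses at the end by stripping the "cm" suffix, which assumes every folder is named "<number>cm". The variable names are illustrative.

```matlab
% Stand-in data: a few random 28x28 grey-scale images in
% <tempdir>/images_demo/2cm and <tempdir>/images_demo/20cm
rootFolder = fullfile(tempdir, 'images_demo');
folderNames = {'2cm', '20cm'};
nPerFolder = 10;
for f = 1:numel(folderNames)
    d = fullfile(rootFolder, folderNames{f});
    if ~exist(d, 'dir')
        mkdir(d);
    end
    for i = 1:nPerFolder
        imwrite(uint8(255 * rand(28, 28)), fullfile(d, sprintf('%d.png', i)));
    end
end

% Label each image with the name of the folder it is in
imds = imageDatastore(rootFolder, 'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');

% 70% of each label for training, the rest for validation, picked at random
[imdsTrain, imdsVal] = splitEachLabel(imds, 0.7, 'randomized');

% Load into H x W x 1 x N arrays and turn '2cm'/'20cm' into numeric responses
trainImgs = readall(imdsTrain);
XTrain = cat(4, trainImgs{:});
YTrain = str2double(erase(string(imdsTrain.Labels), "cm"));

valImgs = readall(imdsVal);
XVal = cat(4, valImgs{:});
YVal = str2double(erase(string(imdsVal.Labels), "cm"));
```

readall loads every image into memory, so it is used here only to produce the same kind of 4D arrays as above; for a data set too large for memory you would keep working with the datastores rather than calling it.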

