How to define custom classification loss function

himahgrft - 2022-04-04T12:07:46+00:00
Question: How to define custom classification loss function

I am currently trying to run a k-fold cross-validation on a decision tree with a custom classification loss function, as described here. However, I don't understand how the C and S matrices, which are passed to the loss function, are helpful.

1. The linked page says: "C is an n-by-K logical matrix with rows indicating which class the corresponding observation belongs to." So this is not predicted, and hence just a repetition of the input data?
2. The S matrix: "S is an n-by-K numeric matrix of classification scores." Why can I not simply use the predicted classifications instead of the scores?

To be more specific: I create a classification decision tree. Next, I use crossval to get a partitionedModel. Then, I calculate the validation accuracy using kfoldLoss. Instead of the built-in 'classiferror' function, I would like to use my own classification loss function, e.g. the Matthews correlation coefficient.

    % create a set of cross-validated classification models from a classification model
    partitionedModel = crossval(trainedClassifier.ClassificationTree, 'KFold', 10);
    % loss (by default the fraction of misclassified data) is a scalar averaged over all folds
    validationAccuracy = 1 - kfoldLoss(partitionedModel, 'LossFun', 'classiferror');

Any help is greatly appreciated.
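A minimal sketch of a custom loss function for kfoldLoss, assuming a two-class problem; the name mccLoss is hypothetical. It illustrates why C and S are sufficient: C is the true class of each held-out observation in one-hot form (so it does repeat the observed labels), and the predicted class can be recovered from S as the column with the largest score, which should match predict under the default misclassification cost. Scores are passed rather than labels so that score-based losses (such as the built-in 'logit' or 'hinge') can be computed too. The weights W and cost matrix Cost are ignored in this sketch, and it returns 1 - MCC so that lower values are better, as for a loss.

```matlab
function loss = mccLoss(C, S, ~, ~)
% Hypothetical custom loss for kfoldLoss(..., 'LossFun', @mccLoss).
% C: n-by-2 logical, true-class indicator (one column per class)
% S: n-by-2 numeric classification scores
% The 3rd (W, weights) and 4th (Cost) inputs are ignored in this sketch.
[~, trueIdx] = max(C, [], 2);   % column index of the true class
[~, predIdx] = max(S, [], 2);   % predicted class = column with the largest score

% confusion counts, treating the first class column as "positive"
TP = sum(trueIdx == 1 & predIdx == 1);
TN = sum(trueIdx == 2 & predIdx == 2);
FP = sum(trueIdx == 2 & predIdx == 1);
FN = sum(trueIdx == 1 & predIdx == 2);

denominator = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN);
if denominator == 0
    mcc = 0;                    % convention: MCC = 0 when undefined
else
    mcc = (TP*TN - FP*FN) / sqrt(denominator);
end
loss = 1 - mcc;                 % 0 = perfect, 2 = completely wrong
end
```

Used with the partitioned model from the question, an MCC-based estimate would then be `mccCV = 1 - kfoldLoss(partitionedModel, 'LossFun', @mccLoss);`. In a script file, the local function has to be placed at the end of the file.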

Expert Answer

John Michell answered on 2025-11-20.

In case anybody else is looking for a solution: I used the crossval function to wrap the training of the decision tree. With this approach, implementing other loss functions or quality measures is straightforward.

 

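The nested function trainDT2 returns testval, a row of quality measures of the prediction (accuracy, F1 score and MCC) derived from the confusion matrix of each test fold; crossval stacks these rows into vals, one row per fold.

function [trainedClassifier, qualityMeasures] = trainDTwCrossVal(data, predictorNames, MaxNumSplits)
% Assumes data is a table with a two-class response column typeBehavior
% and predictor columns named in predictorNames (cell array of names).

% cross validation
numberOfFolds = 5;
cp = cvpartition(data.typeBehavior, 'KFold', numberOfFolds);  % random partition for a stratified k-fold cross-validation
vals = crossval(@trainDT2, data, 'Partition', cp);             % one row [accuracy F1score MCC] per fold
qualityMeasures = mean(vals, 1);                                % average over all folds

% final tree trained on all data (assumed step, not shown in the original)
trainedClassifier = fitctree(data(:, predictorNames), data.typeBehavior, ...
    'MaxNumSplits', MaxNumSplits);

    function testval = trainDT2(trainingData, testingData)  % nested function: train one DT with trainingData, test with testingData
        % training and prediction (assumed steps, elided in the original)
        tree = fitctree(trainingData(:, predictorNames), trainingData.typeBehavior, ...
            'MaxNumSplits', MaxNumSplits);
        predicted = predict(tree, testingData(:, predictorNames));

        % confusionmat: rows = true class, columns = predicted class;
        % the first class in sorted order is treated as positive:
        % C = [TP FN
        %      FP TN]
        C = confusionmat(testingData.typeBehavior, predicted);
        TP = C(1,1); FN = C(1,2); FP = C(2,1); TN = C(2,2);

        % Matthews correlation coefficient, worst value = -1, best value = 1
        denominator = (TP+FP)*(TP+FN)*(TN+FP)*(TN+FN);
        if denominator == 0
            MCC = 0;    % set MCC to zero if the denominator is zero
        else
            MCC = (TP*TN - FP*FN) / sqrt(denominator);
        end

        accuracy = (TP+TN)/(TP+TN+FP+FN);   % accuracy, worst value = 0, best value = 1
        F1score = 2*TP/(2*TP+FP+FN);        % F1 score, worst value = 0, best value = 1
        testval = [accuracy F1score MCC];
    end
end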
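As a worked check of the three formulas, here they are with illustrative, made-up counts TP = 40, FN = 10, FP = 5, TN = 45. It also shows why the orientation of the confusion matrix does not change the result: exchanging FP and FN leaves every numerator and denominator unchanged, so accuracy, F1 score and MCC are the same whichever way the off-diagonal entries are read.

```latex
\mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}
             = \frac{40 \cdot 45 - 5 \cdot 10}{\sqrt{45 \cdot 50 \cdot 50 \cdot 55}}
             = \frac{1750}{\sqrt{6\,187\,500}} \approx 0.7035

\mathrm{accuracy} = \frac{TP + TN}{TP + TN + FP + FN} = \frac{85}{100} = 0.85

F_1 = \frac{2\,TP}{2\,TP + FP + FN} = \frac{80}{95} \approx 0.8421
```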

 

