function [estimated, scores, evals] = ...
    cross_validation_module(features, labels, participants_vect, parameters)
%CROSS_VALIDATION_MODULE Cross-validate a classification/regression pipeline.
% Mohammad Soleymani October 2012
%
% For every fold, the held-out trials are normalized, feature-selected and
% classified with a model fitted on the remaining trials; the per-trial
% predictions are then scored once over the whole dataset.
%
% Inputs
%   features          : [N*F] database to cross validate, N number of
%                       samples (rows), F feature columns
%   labels            : [N*1] label of each sample, labels ranged from 1 to
%                       nbClasses
%   participants_vect : participant code of each sample (one entry per row
%                       of features); may be empty unless
%                       parameters.cross_validation is 'one-participant-out'
%   parameters.nbClasses        : [1] number of classes in the dataset;
%                                 values <= 1 select the regression metrics
%   parameters.cross_validation : 'leave-one-out' or 'one-participant-out'
%   parameters.classifier       : [string] type of classifier to use
%                                 (forwarded to classif_module), e.g.
%                                   FisherDiag : LDA with forced diagonal
%                                                covariance matrix
%                                   SVMLin, SVMPoly, SVMSigmoid, SVMRbf :
%                                                SVM with different kernels
%                                   KNN : K-nearest neighbours
%   parameters.fusion           : fusion method, 'sum' or 'product'
%                                 (not read here, forwarded to the helpers)
%   parameters.nbFolds          : [1] number of folds (default 10); it is
%                                 overwritten by both supported schemes
%                                 (number of participants, or N)
%   parameters.verbose          : print progress for each fold (default true)
%   NOTE(review): the original header also mentioned a parameter "to find by
%   using cross validation" (e.g. the gamma value for SVM); its field name
%   is not recoverable and this function does not read it - confirm against
%   the helper modules.
%
% Outputs
%   estimated : [N*1] first output of classif_module for each trial,
%               produced while that trial was held out
%   scores    : [N*nbClasses] second output of classif_module for each trial
%   evals     : struct with fields classification_rate, prec, recall and
%               f1s when nbClasses > 1; otherwise whatever regressionPrec
%               returns

% seq_labels is forced to false on every call, overriding any value
% supplied by the caller.
parameters.seq_labels = false;

if ~isfield(parameters, 'verbose')
    parameters.verbose = true;
end

if ~isfield(parameters, 'nbFolds')
    parameters.nbFolds = 10;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
nbTrials = size(features, 1);
scores = zeros(nbTrials, parameters.nbClasses);
estimated = zeros(nbTrials, 1);

% Build the list of test-index sets, one cell per fold.
if strcmp(parameters.cross_validation, 'one-participant-out')
    % The grouping vector is mandatory for this scheme and must describe
    % every trial; otherwise some trials would silently never be tested or
    % indices would fall outside features.
    if isempty(participants_vect)
        error('cross_validation_module:missingParticipants', ...
            'participants_vect must not be empty for ''one-participant-out'' cross validation.');
    end
    if numel(participants_vect) ~= nbTrials
        error('cross_validation_module:participantsSizeMismatch', ...
            'participants_vect must have one entry per row of features.');
    end
    participants_codes = unique(participants_vect);
    parameters.nbFolds = length(participants_codes);
    fold = cell(parameters.nbFolds, 1);
    for k = 1:parameters.nbFolds
        fold{k} = find(participants_vect == participants_codes(k));
    end
elseif strcmp(parameters.cross_validation, 'leave-one-out')
    parameters.nbFolds = nbTrials;
    fold = cell(parameters.nbFolds, 1);
    for k = 1:parameters.nbFolds
        % Each trial is its own test set.
        fold{k} = k;
    end
else
    error('cross_validation_module:unknownScheme', ...
        'Unsupported parameters.cross_validation; expected ''one-participant-out'' or ''leave-one-out''.');
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% cross validations : classification
for k = 1:parameters.nbFolds
    if parameters.verbose
        fprintf('fold %d of %d\n', k, parameters.nbFolds);
    end

    % Determine the training and test sets for this fold.
    iTest = fold{k};
    trainSet = setdiff(1:nbTrials, iTest);
    xTrain = features(trainSet, :);
    xTest = features(iTest, :);
    labelsTrain = labels(trainSet);

    % Normalization and feature selection receive the training labels only;
    % the test labels are never passed to the helpers.
    [xTrain, xTest] = normalization_module(xTrain, xTest, parameters);
    [~, xTrain, xTest] = feature_sel_module(xTrain, xTest, labelsTrain, parameters);
    [estimated(iTest), scores(iTest, :)] = classif_module(xTrain, [], xTest, labelsTrain, [], parameters);
end

% Score the pooled held-out predictions against the ground truth.
if parameters.nbClasses > 1
    [evals.classification_rate, evals.prec, evals.recall, evals.f1s] = classificationPrec(estimated, labels, parameters.nbClasses);
else
    [evals] = regressionPrec(estimated, labels);
end