1 |
|
|
2 |
|
package net.sf.classifier4J.vector; |
3 |
|
|
4 |
|
import java.util.Arrays; |
5 |
|
import java.util.Map; |
6 |
|
import java.util.Set; |
7 |
|
|
8 |
|
import net.sf.classifier4J.AbstractCategorizedTrainableClassifier; |
9 |
|
import net.sf.classifier4J.ClassifierException; |
10 |
|
import net.sf.classifier4J.DefaultStopWordsProvider; |
11 |
|
import net.sf.classifier4J.DefaultTokenizer; |
12 |
|
import net.sf.classifier4J.IStopWordProvider; |
13 |
|
import net.sf.classifier4J.ITokenizer; |
14 |
|
import net.sf.classifier4J.Utilities; |
15 |
|
|
16 |
|
|
17 |
|
public class VectorClassifier extends AbstractCategorizedTrainableClassifier { |
18 |
2 |
public static double DEFAULT_VECTORCLASSIFIER_CUTOFF = 0.80d; |
19 |
|
|
20 |
|
|
21 |
6 |
private int numTermsInVector = 25; |
22 |
|
private ITokenizer tokenizer; |
23 |
|
private IStopWordProvider stopWordsProvider; |
24 |
|
private TermVectorStorage storage; |
25 |
|
|
26 |
6 |
public VectorClassifier() { |
27 |
6 |
tokenizer = new DefaultTokenizer(); |
28 |
6 |
stopWordsProvider = new DefaultStopWordsProvider(); |
29 |
6 |
storage = new HashMapTermVectorStorage(); |
30 |
|
|
31 |
6 |
setMatchCutoff(DEFAULT_VECTORCLASSIFIER_CUTOFF); |
32 |
6 |
} |
33 |
|
|
34 |
|
public VectorClassifier(TermVectorStorage storage) { |
35 |
6 |
this(); |
36 |
6 |
this.storage = storage; |
37 |
6 |
} |
38 |
|
|
39 |
|
|
40 |
|
|
41 |
|
|
42 |
|
public double classify(String category, String input) throws ClassifierException { |
43 |
|
|
44 |
|
|
45 |
14 |
Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider); |
46 |
|
|
47 |
14 |
TermVector tv = storage.getTermVector(category); |
48 |
14 |
if (tv == null) { |
49 |
2 |
return 0; |
50 |
|
} else { |
51 |
12 |
int[] inputValues = generateTermValuesVector(tv.getTerms(), wordFrequencies); |
52 |
|
|
53 |
12 |
return VectorUtils.cosineOfVectors(inputValues, tv.getValues()); |
54 |
|
} |
55 |
|
} |
56 |
|
|
57 |
|
|
58 |
|
|
59 |
|
|
60 |
|
|
61 |
|
public boolean isMatch(String category, String input) throws ClassifierException { |
62 |
6 |
return (getMatchCutoff() < classify(category, input)); |
63 |
|
} |
64 |
|
|
65 |
|
|
66 |
|
|
67 |
|
|
68 |
|
|
69 |
|
|
70 |
|
public void teachMatch(String category, String input) throws ClassifierException { |
71 |
|
|
72 |
6 |
Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider); |
73 |
|
|
74 |
|
|
75 |
6 |
Set mostFrequentWords = Utilities.getMostFrequentWords(numTermsInVector, wordFrequencies); |
76 |
|
|
77 |
6 |
String[] terms = (String[]) mostFrequentWords.toArray(new String[mostFrequentWords.size()]); |
78 |
6 |
Arrays.sort(terms); |
79 |
6 |
int[] values = generateTermValuesVector(terms, wordFrequencies); |
80 |
|
|
81 |
6 |
TermVector tv = new TermVector(terms, values); |
82 |
|
|
83 |
6 |
storage.addTermVector(category, tv); |
84 |
|
|
85 |
6 |
return; |
86 |
|
} |
87 |
|
|
88 |
|
|
89 |
|
|
90 |
|
|
91 |
|
|
92 |
|
|
93 |
|
protected int[] generateTermValuesVector(String[] terms, Map wordFrequencies) { |
94 |
18 |
int[] result = new class="keyword">int[terms.length]; |
95 |
108 |
for (int i = 0; i < terms.length; i++) { |
96 |
90 |
Integer value = (Integer)wordFrequencies.get(terms[i]); |
97 |
90 |
if (value == null) { |
98 |
48 |
result[i] = 0; |
99 |
|
} else { |
100 |
42 |
result[i] = value.intValue(); |
101 |
|
} |
102 |
|
|
103 |
|
} |
104 |
18 |
return result; |
105 |
|
} |
106 |
|
|
107 |
|
|
108 |
|
|
109 |
|
|
110 |
|
|
111 |
|
public void teachNonMatch(String category, String input) throws ClassifierException { |
112 |
0 |
return; |
113 |
|
} |
114 |
|
} |