1
2 package net.sf.classifier4J.vector;
3
4 import java.util.Arrays;
5 import java.util.Map;
6 import java.util.Set;
7
8 import net.sf.classifier4J.AbstractCategorizedTrainableClassifier;
9 import net.sf.classifier4J.ClassifierException;
10 import net.sf.classifier4J.DefaultStopWordsProvider;
11 import net.sf.classifier4J.DefaultTokenizer;
12 import net.sf.classifier4J.IStopWordProvider;
13 import net.sf.classifier4J.ITokenizer;
14 import net.sf.classifier4J.Utilities;
15
16
17 public class VectorClassifier extends AbstractCategorizedTrainableClassifier {
18 public static double DEFAULT_VECTORCLASSIFIER_CUTOFF = 0.80d;
19
20
21 private int numTermsInVector = 25;
22 private ITokenizer tokenizer;
23 private IStopWordProvider stopWordsProvider;
24 private TermVectorStorage storage;
25
26 public VectorClassifier() {
27 tokenizer = new DefaultTokenizer();
28 stopWordsProvider = new DefaultStopWordsProvider();
29 storage = new HashMapTermVectorStorage();
30
31 setMatchCutoff(DEFAULT_VECTORCLASSIFIER_CUTOFF);
32 }
33
34 public VectorClassifier(TermVectorStorage storage) {
35 this();
36 this.storage = storage;
37 }
38
39 /***
40 * @see net.sf.classifier4J.ICategorisedClassifier#classify(java.lang.String, java.lang.String)
41 */
42 public double classify(String category, String input) throws ClassifierException {
43
44
45 Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider);
46
47 TermVector tv = storage.getTermVector(category);
48 if (tv == null) {
49 return 0;
50 } else {
51 int[] inputValues = generateTermValuesVector(tv.getTerms(), wordFrequencies);
52
53 return VectorUtils.cosineOfVectors(inputValues, tv.getValues());
54 }
55 }
56
57
58 /***
59 * @see net.sf.classifier4J.ICategorisedClassifier#isMatch(java.lang.String, java.lang.String)
60 */
61 public boolean isMatch(String category, String input) throws ClassifierException {
62 return (getMatchCutoff() < classify(category, input));
63 }
64
65
66
67 /***
68 * @see net.sf.classifier4J.ITrainable#teachMatch(java.lang.String, java.lang.String)
69 */
70 public void teachMatch(String category, String input) throws ClassifierException {
71
72 Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider);
73
74
75 Set mostFrequentWords = Utilities.getMostFrequentWords(numTermsInVector, wordFrequencies);
76
77 String[] terms = (String[]) mostFrequentWords.toArray(new String[mostFrequentWords.size()]);
78 Arrays.sort(terms);
79 int[] values = generateTermValuesVector(terms, wordFrequencies);
80
81 TermVector tv = new TermVector(terms, values);
82
83 storage.addTermVector(category, tv);
84
85 return;
86 }
87
88 /***
89 * @param terms
90 * @param wordFrequencies
91 * @return
92 */
93 protected int[] generateTermValuesVector(String[] terms, Map wordFrequencies) {
94 int[] result = new int[terms.length];
95 for (int i = 0; i < terms.length; i++) {
96 Integer value = (Integer)wordFrequencies.get(terms[i]);
97 if (value == null) {
98 result[i] = 0;
99 } else {
100 result[i] = value.intValue();
101 }
102
103 }
104 return result;
105 }
106
107
108 /***
109 * @see net.sf.classifier4J.ITrainable#teachNonMatch(java.lang.String, java.lang.String)
110 */
111 public void teachNonMatch(String category, String input) throws ClassifierException {
112 return;
113 }
114 }