View Javadoc

1   
2   package net.sf.classifier4J.vector;
3   
4   import java.util.Arrays;
5   import java.util.Map;
6   import java.util.Set;
7   
8   import net.sf.classifier4J.AbstractCategorizedTrainableClassifier;
9   import net.sf.classifier4J.ClassifierException;
10  import net.sf.classifier4J.DefaultStopWordsProvider;
11  import net.sf.classifier4J.DefaultTokenizer;
12  import net.sf.classifier4J.IStopWordProvider;
13  import net.sf.classifier4J.ITokenizer;
14  import net.sf.classifier4J.Utilities;
15  
16  
17  public class VectorClassifier extends AbstractCategorizedTrainableClassifier {
18      public static double DEFAULT_VECTORCLASSIFIER_CUTOFF = 0.80d;
19      
20      
21      private int numTermsInVector = 25;
22      private ITokenizer tokenizer;
23      private IStopWordProvider stopWordsProvider;
24      private TermVectorStorage storage;    
25      
26      public VectorClassifier() {
27          tokenizer = new DefaultTokenizer();
28          stopWordsProvider = new DefaultStopWordsProvider();
29          storage = new HashMapTermVectorStorage();
30          
31          setMatchCutoff(DEFAULT_VECTORCLASSIFIER_CUTOFF);
32      }
33      
34      public VectorClassifier(TermVectorStorage storage) {
35          this();
36          this.storage = storage;
37      }
38      
39      /***
40       * @see net.sf.classifier4J.ICategorisedClassifier#classify(java.lang.String, java.lang.String)
41       */
42      public double classify(String category, String input) throws ClassifierException {
43          
44          // Create a map of the word frequency from the input
45          Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider);
46          
47          TermVector tv = storage.getTermVector(category);
48          if (tv == null) {
49              return 0;
50          } else {
51              int[] inputValues = generateTermValuesVector(tv.getTerms(), wordFrequencies);
52              
53              return VectorUtils.cosineOfVectors(inputValues, tv.getValues());
54          }        
55      }
56  
57  
58      /***
59       * @see net.sf.classifier4J.ICategorisedClassifier#isMatch(java.lang.String, java.lang.String)
60       */
61      public boolean isMatch(String category, String input) throws ClassifierException {
62          return (getMatchCutoff() < classify(category, input));
63      }
64  
65      
66  
67      /***
68       * @see net.sf.classifier4J.ITrainable#teachMatch(java.lang.String, java.lang.String)
69       */
70      public void teachMatch(String category, String input) throws ClassifierException {
71          // Create a map of the word frequency from the input
72          Map wordFrequencies = Utilities.getWordFrequency(input, false, tokenizer, stopWordsProvider);
73          
74          // get the numTermsInVector most used words in the input
75          Set mostFrequentWords = Utilities.getMostFrequentWords(numTermsInVector, wordFrequencies);
76  
77          String[] terms = (String[]) mostFrequentWords.toArray(new String[mostFrequentWords.size()]);
78          Arrays.sort(terms);
79          int[] values = generateTermValuesVector(terms, wordFrequencies);
80          
81          TermVector tv = new TermVector(terms, values);
82          
83          storage.addTermVector(category, tv);        
84          
85          return;
86      }
87  
88      /***
89       * @param terms
90       * @param wordFrequencies
91       * @return
92       */
93      protected int[] generateTermValuesVector(String[] terms, Map wordFrequencies) {
94          int[] result = new int[terms.length];
95          for (int i = 0; i < terms.length; i++) {
96              Integer value = (Integer)wordFrequencies.get(terms[i]);
97              if (value == null) {
98                  result[i] = 0;
99              } else {
100                 result[i] = value.intValue();
101             }
102             
103         }        
104         return result;
105     }
106 
107 
108     /***
109      * @see net.sf.classifier4J.ITrainable#teachNonMatch(java.lang.String, java.lang.String)
110      */
111     public void teachNonMatch(String category, String input) throws ClassifierException {
112         return; // this is not required for the VectorClassifier        
113     }
114 }