Class: Classifier::LogisticRegression
Includes: Streaming, Mutex_m
Defined in: lib/classifier/logistic_regression.rb
Overview
Logistic Regression (MaxEnt) classifier using Stochastic Gradient Descent. Often provides better accuracy than Naive Bayes while remaining fast and interpretable.
Example:
classifier = Classifier::LogisticRegression.new(:spam, :ham)
classifier.train(spam: ["Buy now!", "Free money!!!"])
classifier.train(ham: ["Meeting tomorrow", "Project update"])
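classifier.fit # required first: classify and probabilities raise NotFittedError on an unfitted model
classifier.classify("Claim your prize!") # => "Spam"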
classifier.probabilities("Claim your prize!") # => {"Spam" => 0.92, "Ham" => 0.08}
Constant Summary
- DEFAULT_LEARNING_RATE = 0.1
- DEFAULT_REGULARIZATION = 0.01
- DEFAULT_MAX_ITERATIONS = 100
- DEFAULT_TOLERANCE = 1e-4
Constants included from Streaming
Instance Attribute Summary
- #storage ⇒ Object
  Returns the value of attribute storage.
Class Method Summary
- .from_json(json) ⇒ Object
  Loads a classifier from a JSON string or Hash.
- .load(storage:) ⇒ Object
  Loads a classifier from the configured storage.
- .load_checkpoint(storage:, checkpoint_id:) ⇒ Object
  Loads a classifier from a checkpoint.
- .load_from_file(path) ⇒ Object
  Loads a classifier from a file.
Instance Method Summary
- #add_category(category) ⇒ Object
  Adds a new category to the classifier.
- #as_json(_options = nil) ⇒ Object
  Returns a hash representation of the classifier state.
- #categories ⇒ Object
  Returns the list of categories.
- #classifications(text) ⇒ Object
  Returns log-odds scores for each category (before softmax).
- #classify(text) ⇒ Object
  Returns the best matching category for the provided text.
- #dirty? ⇒ Boolean
  Returns true if there are unsaved changes.
- #fit ⇒ Object
  Fits the model to all accumulated training data.
- #fitted? ⇒ Boolean
  Returns true if the model has been fitted.
- #initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, regularization: DEFAULT_REGULARIZATION, max_iterations: DEFAULT_MAX_ITERATIONS, tolerance: DEFAULT_TOLERANCE, min_word_length: Classifier.config.min_word_length) ⇒ LogisticRegression (constructor)
  Creates a new Logistic Regression classifier with the specified categories.
- #marshal_dump ⇒ Object
  Custom marshal serialization to exclude mutex state.
- #marshal_load(data) ⇒ Object
  Custom marshal deserialization to recreate mutex.
- #method_missing(name, *args) ⇒ Object
  Provides training methods for the categories.
- #probabilities(text) ⇒ Object
  Returns probability distribution across all categories.
- #reload ⇒ Object
  Reloads the classifier from storage, raising if there are unsaved changes.
- #reload! ⇒ Object
  Force reloads the classifier from storage, discarding any unsaved changes.
- #respond_to_missing?(name, include_private = false) ⇒ Boolean
- #save ⇒ Object
  Saves the classifier to the configured storage.
- #save_to_file(path) ⇒ Object
  Saves the classifier state to a file.
- #to_json(_options = nil) ⇒ Object
  Serializes the classifier state to a JSON string.
- #train(category = nil, text = nil, **categories) ⇒ Object
  Trains the classifier with text for a category.
- #train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block) ⇒ Object
  Trains the classifier with an array of documents in batches.
- #train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) ⇒ Object
  Trains the classifier from an IO stream.
- #weights(category, limit: nil) ⇒ Object
  Returns feature weights for a category, sorted by importance.
Methods included from Streaming
#delete_checkpoint, #list_checkpoints, #save_checkpoint
Constructor Details
#initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, regularization: DEFAULT_REGULARIZATION, max_iterations: DEFAULT_MAX_ITERATIONS, tolerance: DEFAULT_TOLERANCE, min_word_length: Classifier.config.min_word_length) ⇒ LogisticRegression
Creates a new Logistic Regression classifier with the specified categories.
classifier = Classifier::LogisticRegression.new(:spam, :ham)
classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral')
classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral'])
Options:
- learning_rate: Step size for gradient descent (default: 0.1)
- regularization: L2 regularization strength (default: 0.01)
- max_iterations: Maximum training iterations (default: 100)
- tolerance: Convergence threshold (default: 1e-4)
- min_word_length: Minimum word length filter in tokenization
# File 'lib/classifier/logistic_regression.rb', line 62
def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
               regularization: DEFAULT_REGULARIZATION,
               max_iterations: DEFAULT_MAX_ITERATIONS,
               tolerance: DEFAULT_TOLERANCE,
               min_word_length: Classifier.config.min_word_length)
  super()
  categories = categories.flatten
  @categories = categories.map { |c| c.to_s.prepare_category_name }
  @weights = @categories.to_h { |c| [c, {}] }
  @bias = @categories.to_h { |c| [c, 0.0] }
  @vocabulary = {}
  @training_data = []
  @learning_rate = learning_rate
  @regularization = regularization
  @max_iterations = max_iterations
  @tolerance = tolerance
  @fitted = false
  @dirty = false
  @storage = nil
  @min_word_length = min_word_length
end
Dynamic Method Handling
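This class handles dynamic methods through the method_missing method.
#method_missing(name, *args) ⇒ Object
Provides training methods for the categories. A call named train_<category> forwards each argument to #train; an unknown category raises StandardError.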
classifier.train_spam "Buy now!"
# File 'lib/classifier/logistic_regression.rb', line 216
def method_missing(name, *args)
  category_match = name.to_s.match(/train_(\w+)/)
  return super unless category_match

  category = category_match[1].to_s.prepare_category_name
  raise StandardError, "No such category: #{category}" unless @categories.include?(category)

  args.each { |text| train(category, text) }
end
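The dispatch pattern can be tried in isolation with a minimal sketch. TinyTrainer is a hypothetical stand-in class, not part of the library, and it assumes category names are normalized with String#capitalize in place of prepare_category_name, whose behavior is not shown here. Unlike the source, the sketch anchors the pattern with \A and \z.

```ruby
# Hypothetical stand-in class; not the library's implementation.
class TinyTrainer
  attr_reader :examples

  def initialize(*categories)
    # Assumption: names are normalized by capitalizing them.
    @categories = categories.flatten.map { |c| c.to_s.capitalize }
    @examples = Hash.new { |hash, key| hash[key] = [] }
  end

  def train(category, text)
    @examples[category] << text
  end

  # Route train_<category>(*texts) calls to #train.
  def method_missing(name, *args)
    match = name.to_s.match(/\Atrain_(\w+)\z/)
    return super unless match

    category = match[1].capitalize
    raise StandardError, "No such category: #{category}" unless @categories.include?(category)

    args.each { |text| train(category, text) }
  end

  # Keep respond_to? consistent with method_missing.
  def respond_to_missing?(name, include_private = false)
    name.to_s.match?(/\Atrain_(\w+)\z/) || super
  end
end

trainer = TinyTrainer.new(:spam, :ham)
trainer.train_spam('Buy now!', 'Free money!!!')
```

Defining respond_to_missing? alongside method_missing keeps respond_to?(:train_spam) truthful, which is why the class documents both.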
Instance Attribute Details
#storage ⇒ Object
Returns the value of attribute storage.
# File 'lib/classifier/logistic_regression.rb', line 39
def storage
  @storage
end
Class Method Details
.from_json(json) ⇒ Object
Loads a classifier from a JSON string or Hash.
# File 'lib/classifier/logistic_regression.rb', line 263
def self.from_json(json)
  data = json.is_a?(String) ? JSON.parse(json) : json
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'logistic_regression'

  categories = data['categories'].map(&:to_sym)
  instance = allocate
  instance.send(:restore_state, data, categories)
  instance
end
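The String-or-Hash handling and the type check can be illustrated on their own. This is a sketch only: parse_state is a hypothetical helper, and the state hash contains just the two keys the check and category lookup read.

```ruby
require 'json'

# Hypothetical helper mirroring the String-or-Hash handling in from_json.
def parse_state(json)
  data = json.is_a?(String) ? JSON.parse(json) : json
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'logistic_regression'

  data
end

state = { 'type' => 'logistic_regression', 'categories' => %w[Spam Ham] }

from_hash = parse_state(state)
from_string = parse_state(JSON.generate(state))
puts from_hash == from_string
```

Because the check reads data['type'] with a string key, a Hash argument appears to need string keys. Judging from the source shown, the symbol-keyed hash returned by as_json would fail this check unless its keys are stringified or it goes through to_json first.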
.load(storage:) ⇒ Object
Loads a classifier from the configured storage.
# File 'lib/classifier/logistic_regression.rb', line 295
def self.load(storage:)
  data = storage.read
  raise StorageError, 'No saved state found' unless data

  instance = from_json(data)
  instance.storage = storage
  instance
end
.load_checkpoint(storage:, checkpoint_id:) ⇒ Object
Loads a classifier from a checkpoint.
# File 'lib/classifier/logistic_regression.rb', line 364
def self.load_checkpoint(storage:, checkpoint_id:)
  raise ArgumentError, 'Storage must be File storage for checkpoints' unless storage.is_a?(Storage::File)

  dir = File.dirname(storage.path)
  base = File.basename(storage.path, '.*')
  ext = File.extname(storage.path)
  checkpoint_path = File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")

  checkpoint_storage = Storage::File.new(path: checkpoint_path)
  instance = load(storage: checkpoint_storage)
  instance.storage = storage
  instance
end
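The checkpoint file name is derived from the main storage path. The sketch below repeats only that path arithmetic in a hypothetical helper, checkpoint_path_for, so the naming scheme can be checked without a storage object. The example path and checkpoint id are made up.

```ruby
# Hypothetical helper reproducing the path arithmetic from load_checkpoint.
def checkpoint_path_for(path, checkpoint_id)
  dir = File.dirname(path)
  base = File.basename(path, '.*')
  ext = File.extname(path)
  File.join(dir, "#{base}_checkpoint_#{checkpoint_id}#{ext}")
end

puts checkpoint_path_for('models/spam.json', '50pct')
# prints models/spam_checkpoint_50pct.json
```

The checkpoint sits next to the main file and keeps its extension, while the loaded instance is pointed back at the main storage.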
.load_from_file(path) ⇒ Object
Loads a classifier from a file.
# File 'lib/classifier/logistic_regression.rb', line 307
def self.load_from_file(path)
  from_json(File.read(path))
end
Instance Method Details
#add_category(category) ⇒ Object
Adds a new category to the classifier. Allows dynamic category creation for CLI and incremental training.
# File 'lib/classifier/logistic_regression.rb', line 187
def add_category(category)
  cat = category.to_s.prepare_category_name
  synchronize do
    return if @categories.include?(cat)

    @categories << cat
    @weights[cat] = {}
    @bias[cat] = 0.0
    @fitted = false
    @dirty = true
  end
end
#as_json(_options = nil) ⇒ Object
Returns a hash representation of the classifier state. Does NOT auto-fit; saves current state including unfitted models.
# File 'lib/classifier/logistic_regression.rb', line 235
def as_json(_options = nil)
  {
    version: 1,
    type: 'logistic_regression',
    categories: @categories.map(&:to_s),
    weights: @weights.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
    bias: @bias.transform_keys(&:to_s),
    vocabulary: @vocabulary.keys.map(&:to_s),
    training_data: @training_data.map { |d| { category: d[:category].to_s, features: d[:features].transform_keys(&:to_s) } },
    learning_rate: @learning_rate,
    regularization: @regularization,
    max_iterations: @max_iterations,
    tolerance: @tolerance,
    fitted: @fitted,
    min_word_length: @min_word_length
  }
end
#categories ⇒ Object
Returns the list of categories.
# File 'lib/classifier/logistic_regression.rb', line 179
def categories
  synchronize { @categories.map(&:to_s) }
end
#classifications(text) ⇒ Object
Returns log-odds scores for each category (before softmax). Raises NotFittedError if model has not been fitted.
# File 'lib/classifier/logistic_regression.rb', line 149
def classifications(text)
  raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

  features = text.word_hash(@min_word_length)
  synchronize do
    compute_scores(features).transform_keys(&:to_s)
  end
end
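The private compute_scores method is not shown on this page. As a rough illustration of what a pre-softmax score usually is in logistic regression, the sketch below assumes one linear score per category: the category bias plus the dot product of the category weights with the word counts. linear_scores and the sample numbers are hypothetical.

```ruby
# Hypothetical stand-in for compute_scores: one linear score per category.
def linear_scores(weights, bias, features)
  weights.to_h do |category, category_weights|
    # Words without a learned weight contribute nothing.
    dot = features.sum(0.0) { |word, count| category_weights.fetch(word, 0.0) * count }
    [category, bias.fetch(category, 0.0) + dot]
  end
end

weights = {
  'Spam' => { free: 2.0, meeting: -1.0 },
  'Ham' => { free: -2.0, meeting: 1.0 }
}
bias = { 'Spam' => 0.5, 'Ham' => -0.5 }

scores = linear_scores(weights, bias, { free: 2, prize: 1 })
puts scores.inspect
```

With these numbers the Spam score is 0.5 + 2.0 * 2 = 4.5 and the Ham score is -0.5 - 2.0 * 2 = -4.5; the unseen word prize changes neither.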
#classify(text) ⇒ Object
Returns the best matching category for the provided text.
classifier.classify("Buy now!") # => "Spam"
# File 'lib/classifier/logistic_regression.rb', line 120
def classify(text)
  probs = probabilities(text)
  best = probs.max_by { |_, v| v }
  raise StandardError, 'No classifications available' unless best

  best.first
end
#dirty? ⇒ Boolean
Returns true if there are unsaved changes.
# File 'lib/classifier/logistic_regression.rb', line 210
def dirty?
  @dirty
end
#fit ⇒ Object
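Fits the model to all accumulated training data. Must be called explicitly before #classify, #probabilities, or #classifications, which raise NotFittedError on an unfitted model; only #weights and #marshal_dump fit on demand. Returns self without fitting if no training data has been added, and raises ArgumentError if fewer than two categories are defined.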
# File 'lib/classifier/logistic_regression.rb', line 103
def fit
  synchronize do
    return self if @training_data.empty?
    raise ArgumentError, 'At least two categories required for fitting' if @categories.size < 2

    optimize_weights
    @fitted = true
    @dirty = false
  end
  self
end
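The private optimize_weights method is not shown on this page. The following is a generic sketch of multinomial logistic regression trained with stochastic gradient descent and L2 regularization, written to show where learning_rate, regularization, max_iterations, and tolerance could enter. The names sgd_fit and softmax_probs, the fixed random seed, the unregularized bias, and the stopping rule (change in mean log-loss between passes) are all assumptions, not the library's code.

```ruby
# Numerically stable softmax over a { category => score } hash.
def softmax_probs(scores)
  max = scores.values.max
  exps = scores.transform_values { |s| Math.exp(s - max) }
  total = exps.values.sum
  exps.transform_values { |e| e / total }
end

# Hypothetical trainer; data is an array of { category:, features: } hashes.
def sgd_fit(data, categories, learning_rate: 0.1, regularization: 0.01,
            max_iterations: 100, tolerance: 1e-4)
  weights = categories.to_h { |c| [c, Hash.new(0.0)] }
  bias = categories.to_h { |c| [c, 0.0] }
  rng = Random.new(42) # fixed seed so runs are repeatable
  previous_loss = Float::INFINITY

  max_iterations.times do
    loss = 0.0
    data.shuffle(random: rng).each do |example|
      features = example[:features]
      scores = categories.to_h do |c|
        [c, bias[c] + features.sum(0.0) { |f, x| weights[c][f] * x }]
      end
      probs = softmax_probs(scores)
      loss -= Math.log([probs[example[:category]], 1e-15].max)

      categories.each do |c|
        # Gradient of the log-loss with respect to the score of category c.
        error = probs[c] - (c == example[:category] ? 1.0 : 0.0)
        bias[c] -= learning_rate * error
        features.each do |f, x|
          gradient = error * x + regularization * weights[c][f]
          weights[c][f] -= learning_rate * gradient
        end
      end
    end

    loss /= data.size
    # Assumed stopping rule: mean loss changed by less than tolerance.
    break if (previous_loss - loss).abs < tolerance

    previous_loss = loss
  end

  [weights, bias]
end

training_data = [
  { category: 'Spam', features: { buy: 1, now: 1 } },
  { category: 'Spam', features: { free: 1, money: 1 } },
  { category: 'Ham', features: { meeting: 1, tomorrow: 1 } },
  { category: 'Ham', features: { project: 1, update: 1 } }
]
weights, bias = sgd_fit(training_data, %w[Spam Ham])
```

In this sketch a word seen only in spam examples ends up with a positive Spam weight and a negative Ham weight, and the regularization term shrinks every touched weight slightly on each update.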
#fitted? ⇒ Boolean
Returns true if the model has been fitted.
# File 'lib/classifier/logistic_regression.rb', line 203
def fitted?
  @fitted
end
#marshal_dump ⇒ Object
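Custom marshal serialization to exclude mutex state. Fits the model first if it has not been fitted; accumulated training data, the storage reference, and the dirty flag are not serialized.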
# File 'lib/classifier/logistic_regression.rb', line 343
def marshal_dump
  fit unless @fitted
  [@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
   @max_iterations, @tolerance, @fitted, @min_word_length]
end
#marshal_load(data) ⇒ Object
Custom marshal deserialization to recreate mutex.
# File 'lib/classifier/logistic_regression.rb', line 352
def marshal_load(data)
  mu_initialize
  @categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
    @max_iterations, @tolerance, @fitted, @min_word_length = data
  @training_data = []
  @dirty = false
  @storage = nil
end
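The reason for the custom hooks is that Marshal cannot serialize a mutex. The sketch below shows the same dump-data-only, rebuild-lock-on-load pattern with a plain Mutex instance variable instead of Mutex_m. GuardedModel is a hypothetical class used only for illustration.

```ruby
# Hypothetical class holding a Mutex, which Marshal cannot serialize directly.
class GuardedModel
  attr_reader :weights

  def initialize(weights)
    @lock = Mutex.new
    @weights = weights
  end

  def update(word, value)
    @lock.synchronize { @weights[word] = value }
  end

  # Serialize only plain data; the Mutex is left out.
  def marshal_dump
    [@weights]
  end

  # Restore the data and build a fresh Mutex for the new object.
  def marshal_load(data)
    @weights, = data
    @lock = Mutex.new
  end
end

copy = Marshal.load(Marshal.dump(GuardedModel.new({ 'free' => 2.3 })))
copy.update('buy', 1.8) # the restored object has a working lock
```

Anything left out of the dumped array, such as the training data in the class documented here, has to be reset to a sensible default in marshal_load.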
#probabilities(text) ⇒ Object
Returns the probability distribution across all categories. Probabilities are typically better calibrated than those of Naive Bayes. Raises NotFittedError if the model has not been fitted.
classifier.probabilities("Buy now!")
# => {"Spam" => 0.92, "Ham" => 0.08}
# File 'lib/classifier/logistic_regression.rb', line 136
def probabilities(text)
  raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

  features = text.word_hash(@min_word_length)
  synchronize do
    softmax(compute_scores(features))
  end
end
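The private softmax method is not shown on this page. A common way to write it, given here as an assumption rather than the library's code, subtracts the largest score before exponentiating so that large scores do not overflow.

```ruby
# Numerically stable softmax over a { category => score } hash.
def softmax(scores)
  max = scores.values.max
  exps = scores.transform_values { |s| Math.exp(s - max) }
  total = exps.values.sum
  exps.transform_values { |e| e / total }
end

probs = softmax({ 'Spam' => 2.0, 'Ham' => 0.0 })
puts probs.inspect
```

For scores 2.0 and 0.0 the Spam probability is 1 / (1 + e^-2), roughly 0.88, and the values always sum to 1.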
#reload ⇒ Object
Reloads the classifier from storage, raising if there are unsaved changes.
# File 'lib/classifier/logistic_regression.rb', line 314
def reload
  raise ArgumentError, 'No storage configured' unless storage
  raise UnsavedChangesError, 'Unsaved changes would be lost. Call save first or use reload!' if @dirty

  data = storage.read
  raise StorageError, 'No saved state found' unless data

  restore_from_json(data)
  @dirty = false
  self
end
#reload! ⇒ Object
Force reloads the classifier from storage, discarding any unsaved changes.
# File 'lib/classifier/logistic_regression.rb', line 329
def reload!
  raise ArgumentError, 'No storage configured' unless storage

  data = storage.read
  raise StorageError, 'No saved state found' unless data

  restore_from_json(data)
  @dirty = false
  self
end
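The difference between #reload and #reload! comes down to the dirty flag. The sketch below models that guard with hypothetical stand-ins: MemoryStorage exposes only the read and write calls seen in the source, Notebook replaces the classifier, and UnsavedChangesError is redefined locally so the example runs on its own.

```ruby
# Hypothetical in-memory storage exposing only read and write.
class MemoryStorage
  def write(data)
    @data = data
  end

  def read
    @data
  end
end

# Local stand-in for the library's error class.
class UnsavedChangesError < StandardError; end

# Hypothetical model; only the dirty-flag logic matters here.
class Notebook
  attr_reader :text

  def initialize(storage)
    @storage = storage
    @text = ''
    @dirty = false
  end

  def append(line)
    @text += line
    @dirty = true
  end

  def save
    @storage.write(@text)
    @dirty = false
  end

  # Refuses to overwrite unsaved changes.
  def reload
    raise UnsavedChangesError, 'Unsaved changes would be lost' if @dirty

    reload!
  end

  # Always restores the last saved state.
  def reload!
    @text = @storage.read.to_s
    @dirty = false
    self
  end
end

notebook = Notebook.new(MemoryStorage.new)
notebook.append('saved')
notebook.save
notebook.append(' unsaved')
notebook.reload! # discards ' unsaved'; plain reload would raise here
```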
#respond_to_missing?(name, include_private = false) ⇒ Boolean
# File 'lib/classifier/logistic_regression.rb', line 227
def respond_to_missing?(name, include_private = false)
  !!(name.to_s =~ /train_(\w+)/) || super
end
#save ⇒ Object
Saves the classifier to the configured storage.
# File 'lib/classifier/logistic_regression.rb', line 276
def save
  raise ArgumentError, 'No storage configured' unless storage

  storage.write(to_json)
  @dirty = false
end
#save_to_file(path) ⇒ Object
Saves the classifier state to a file.
# File 'lib/classifier/logistic_regression.rb', line 286
def save_to_file(path)
  result = File.write(path, to_json)
  @dirty = false
  result
end
#to_json(_options = nil) ⇒ Object
Serializes the classifier state to a JSON string.
# File 'lib/classifier/logistic_regression.rb', line 256
def to_json(_options = nil)
  JSON.generate(as_json)
end
#train(category = nil, text = nil, **categories) ⇒ Object
Trains the classifier with text for a category.
classifier.train(spam: "Buy now!", ham: ["Hello", "Meeting tomorrow"])
classifier.train(:spam, "legacy positional API")
# File 'lib/classifier/logistic_regression.rb', line 91
def train(category = nil, text = nil, **categories)
  return train_single(category, text) if category && text

  categories.each do |cat, texts|
    (texts.is_a?(Array) ? texts : [texts]).each { |t| train_single(cat, t) }
  end
end
#train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block) ⇒ Object
Trains the classifier with an array of documents in batches. Note: The model is NOT automatically fitted after batch training. Call #fit to train the model after adding all data.
# File 'lib/classifier/logistic_regression.rb', line 417
def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  if category && documents
    train_batch_for_category(category, documents, batch_size: batch_size, &block)
  else
    categories.each do |cat, docs|
      train_batch_for_category(cat, Array(docs), batch_size: batch_size, &block)
    end
  end
end
#train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories) ⇒ Object
Trains the classifier from an IO stream. Each line in the stream is treated as a separate document. Note: The model is NOT automatically fitted after streaming. Call #fit to train the model after adding all data.
# File 'lib/classifier/logistic_regression.rb', line 394
def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &)
  raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty?
  raise ArgumentError, 'Provide both category and io, or use keyword arguments' if [category, io].one?(&:nil?)

  pairs = category && io ? { category => io } : categories
  pairs.each do |cat, stream|
    stream_train_category(cat, stream, batch_size:, &)
  end
end
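The private stream_train_category method is not shown on this page. The sketch below illustrates the general idea of treating each line of an IO as one document and handing documents over in fixed-size batches. each_document_batch is a hypothetical helper, and the batch size of 2 is chosen only to keep the output short.

```ruby
require 'stringio'

# Hypothetical helper: read an IO line by line and yield batches of documents.
def each_document_batch(io, batch_size:)
  # each_line without a block returns an enumerator, so lines are pulled
  # from the stream as batches are filled rather than read all at once.
  io.each_line(chomp: true).each_slice(batch_size) { |batch| yield batch }
end

io = StringIO.new("Buy now!\nFree money!!!\nClaim your prize!\n")
each_document_batch(io, batch_size: 2) { |batch| puts batch.inspect }
# prints ["Buy now!", "Free money!!!"]
# prints ["Claim your prize!"]
```

Any object that responds to each_line works in place of the StringIO, for example a File opened for reading.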
#weights(category, limit: nil) ⇒ Object
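Returns feature weights for a category, sorted by absolute weight with the largest first. Positive weights indicate the feature supports the category; negative weights count against it. Fits the model first if it has not been fitted, and raises StandardError for an unknown category.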
classifier.weights(:spam)
# => {:free => 2.3, :buy => 1.8, :money => 1.5, ...}
# File 'lib/classifier/logistic_regression.rb', line 165
def weights(category, limit: nil)
  fit unless @fitted

  cat = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{cat}" unless @weights.key?(cat)

  sorted = @weights[cat].sort_by { |_, v| -v.abs }
  sorted = sorted.first(limit) if limit
  sorted.to_h
end
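The ordering uses the magnitude of each weight, so a strongly negative feature ranks above a weaker positive one. The sketch below applies the same sort to a made-up weight hash; top_weights and the numbers are hypothetical.

```ruby
# Hypothetical helper: order features by absolute weight, largest first.
def top_weights(weights, limit: nil)
  sorted = weights.sort_by { |_, v| -v.abs }
  sorted = sorted.first(limit) if limit
  sorted.to_h
end

spam_weights = { free: 2.3, meeting: -2.9, buy: 1.8, the: 0.1 }
top_weights(spam_weights, limit: 2)
# => {:meeting => -2.9, :free => 2.3}
```

Here meeting comes first even though it argues against the category, because its magnitude of 2.9 is the largest.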