-------------------------------------------------------------------------
--
--  Train trip predictor for Magic Ink
--  by Bret Victor
--
--  This is Lua source code.  See http://lua.org for language info.
--  I chose Lua because it's a wonderful little language that deserves
--  wider recognition, it's easy to read, and making a user simulator
--  without coroutines is masochism.
--
-------------------------------------------------------------------------
--  Demo

-- Entry point: runs every demo in turn.
function main ()
    testHumans()
    testNoise()
end

-- Simulates each of the noise-free humans against a fresh learning
-- predictor and prints the resulting weekly log.
function testHumans ()
    print( simulate(getNewLearningPredictor(), simple_human, 200) )
    print( simulate(getNewLearningPredictor(), human_who_switches, 200) )
    print( simulate(getNewLearningPredictor(), human_with_monday_and_thursday, 200) )
    print( simulate(getNewLearningPredictor(), human_with_multicheck, 1000) )
end

-- Sweeps the noise level of human_with_noise from 0.0 to 2.0 in steps of
-- 0.1 and, for each level, prints the mispredictions summed over the
-- weekly log entries matched below.
function testNoise ()
    print(wrapString(human_with_noise.description))

    -- Iterate over integer steps and derive the level from them.  A loop
    -- written as "for level = 0, 2, 0.1" accumulates rounding error, so
    -- its control value overshoots 2 on the last step and the final
    -- noise level is silently skipped.
    for step = 0, 20 do
        local noise_level = step / 10
        human_with_noise.random_checks_per_real_check = noise_level

        -- More noise means more planner checks per real trip, so scale
        -- the number of trials with the noise level.
        local log = simulate(getNewLearningPredictor(), human_with_noise,
                             1000 * (1 + noise_level))

        local misses = 0
        local total = 0
        -- Only weeks numbered 30 through 79 are matched by this pattern;
        -- each matched week is counted as 10 predictions.
        for misses_this_week in string.gmatch(log, "Week [3-7]%d: (%d+)") do
            misses = misses + tonumber(misses_this_week)
            total = total + 10
        end

        print(string.format("noise %0.1f: %d misses out of %d",
                            noise_level, misses, total))
    end
end


-------------------------------------------------------------------------
--  simulation driver

-- Runs `human` against `predictor` for `number_of_trials` planner checks
-- (200 when omitted) and returns the reporter's weekly log string.
--
-- For every check the predictor first predicts, the outcome is reported,
-- and only then is the predictor shown the real answer.
function simulate (predictor, human, number_of_trials)
    local reporter = getNewReporter(predictor, human)

    for _ = 1, number_of_trials or 200 do
        local date_that_human_looked, route_that_human_wants_to_see =
            human.lookAtTripPlanner()
        local predicted_route = predictor.predict(date_that_human_looked)

        reporter.report(date_that_human_looked,
                        route_that_human_wants_to_see,
                        predicted_route)
        predictor.learn(date_that_human_looked, route_that_human_wants_to_see)
    end

    return reporter.weekly_log
end


-------------------------------------------------------------------------
--  Learning predictor

-- Creates a predictor that remembers the trips it has been shown and
-- predicts the route for a new date by weighted voting over that history.
--
-- The returned table has:
--   description     human-readable name of the predictor
--   learn(date, route)   records one observed planner check
--   predict(date)        returns the predicted route string
function getNewLearningPredictor ()

    local learning_predictor = { description = "Learning predictor" }


    ---------------------------------------------------------------------
    --  History management
    --
    --  history[1] is always the most recent example.

    local history = {}
    local max_history_length = 200

    local function historyLength ()
        return #history
    end

    local function isHistoryEmpty ()
        return historyLength() == 0
    end

    local function addExampleToHistory (date, route)
        table.insert(history, 1, { date = date, route = route })
        if historyLength() > max_history_length then
            table.remove(history)    -- drops the last (oldest) example
        end
    end

    -- Stretches the most recent example so that it also covers `date`.
    local function extendLastExample (date)
        history[1].date.ending_hour = date.hour
    end

    -- Calls f(date, route, is_last_example) for each stored example,
    -- newest first.
    local function foreachTrainingExample (f)
        for i, example in ipairs(history) do
            f(example.date, example.route, i == 1)
        end
    end


    ---------------------------------------------------------------------
    --  Learning

    local extension_hours = 2/3

    -- Records that the human wanted `route` at `date`.
    --
    -- A check for the same route, on the same day, within extension_hours
    -- of the end of the most recent example is folded into that example
    -- (its ending_hour is moved forward).  Anything else starts a new
    -- example.
    function learning_predictor.learn (date, route)
        local last = history[1]

        -- The route being learned is the `route` argument; date tables
        -- carry no route field, so it must not be read from `date`.
        local starts_new_example =
            last == nil
            or route ~= last.route
            or date.day ~= last.date.day
            or date.hour > last.date.ending_hour + extension_hours

        if starts_new_example then
            addExampleToHistory(date, route)
        else
            extendLastExample(date)
        end
    end


    ---------------------------------------------------------------------
    --  Predicting

    local default_prediction = "Castro Valley to Del Norte"
    local last_example_weight_bonus = 1
    local bandwidth_hours = 3
    local wrong_day_weight_with_large_history = 1/6
    local wrong_day_weight_with_small_history = 1/2

    -- Smooth ramp: maps 0 to 0 and 1 to 1 along half a cosine wave.
    local function warp (v)
        return 0.5 + 0.5 * math.cos((1 - v) * math.pi)
    end

    -- Weight is 1 inside the example's [starting_hour, ending_hour] span
    -- and falls off smoothly to 0 over bandwidth_hours on either side.
    local function getWeightForHour (example_date, current_date)
        local min_hour = example_date.starting_hour
        local max_hour = example_date.ending_hour
        local current_hour = current_date.hour

        if current_hour <= min_hour - bandwidth_hours then
            return 0
        end
        if current_hour <= min_hour then
            return warp(1 - (min_hour - current_hour) / bandwidth_hours)
        end
        if current_hour <= max_hour then
            return 1
        end
        if current_hour < max_hour + bandwidth_hours then
            return warp(1 - (current_hour - max_hour) / bandwidth_hours)
        end
        return 0
    end

    -- Examples from the same weekday count fully; other weekdays are
    -- discounted, and more heavily once the history exceeds 20 examples.
    local function getWeightForDay (example_date, current_date)
        if example_date.day_of_week == current_date.day_of_week then
            return 1
        end
        if historyLength() > 20 then
            return wrong_day_weight_with_large_history
        end
        return wrong_day_weight_with_small_history
    end

    -- Linear decay per whole week of age, with a floor of 0.1 from six
    -- weeks onwards.
    local function getWeightForAge (example_date, current_date)
        local weeks_ago = math.floor((current_date.day - example_date.day) / 7)
        if weeks_ago >= 6 then
            return 0.1
        end
        return 1 - (weeks_ago / 6)
    end

    -- Returns the route predicted for `prediction_date`.
    --
    --   * With no history, returns default_prediction.
    --   * If the most recent example is from the same day and ended less
    --     than half an hour before the prediction hour, repeats its route.
    --   * Otherwise every example votes for its route with the product of
    --     its hour, day and age weights (the most recent example gets an
    --     extra bonus) and the route with the largest total wins.
    function learning_predictor.predict (prediction_date)
        if isHistoryEmpty() then
            return default_prediction
        end

        local last = history[1]
        if last.date.day == prediction_date.day
           and prediction_date.hour - last.date.ending_hour < 1/2 then
            return last.route
        end

        local votes = {}    -- votes[route] = weight
        foreachTrainingExample(function (training_date, training_route, is_last_example)
            local weight = getWeightForHour(training_date, prediction_date)
                         * getWeightForDay(training_date, prediction_date)
                         * getWeightForAge(training_date, prediction_date)
            if is_last_example then
                weight = weight + last_example_weight_bonus
            end
            votes[training_route] = (votes[training_route] or 0) + weight
        end)

        return getKeyWithMaxValue(votes)
    end

    return learning_predictor
end


-------------------------------------------------------------------------
--  Reporter

-- Creates a reporter that scores a predictor's guesses.
--
-- The returned table has:
--   weekly_log     string; one summary line is appended per finished week
--   detailed_log   string; one line is appended per reported check
--   report(date, route, predicted_route)
--
-- Both logs start with the (wrapped) descriptions of `predictor` and
-- `human`.
function getNewReporter (predictor, human)

    local previous_day_of_week = "Monday"
    local mispredictions_this_week = 0
    local total_predictions_this_week = 0
    local number_of_weeks = 0

    local log_intro = string.format("Predictor: %s\nHuman: %s\n\n",
                                    wrapString(predictor.description),
                                    wrapString(human.description))

    local reporter = { weekly_log = log_intro, detailed_log = log_intro }

    function reporter.report (date, route, predicted_route)
        local detail_entry = string.format("%s: predicted: %s, actual: %s\n",
                                           tostring(date), predicted_route, route)
        reporter.detailed_log = reporter.detailed_log .. detail_entry

        -- Don't count intentional noise.
        if string.find(route, "noise", 1, true) then
            return
        end

        -- The first check on a Monday after any other weekday closes the
        -- week that just ended.
        if date.day_of_week == "Monday" and previous_day_of_week ~= "Monday" then
            local miss_rate = math.ceil(100 * mispredictions_this_week
                                            / total_predictions_this_week)
            local week_entry = string.format("Week %d: %d mispredictions, %d%% miss rate\n",
                                             number_of_weeks + 1,
                                             mispredictions_this_week,
                                             miss_rate)
            reporter.weekly_log = reporter.weekly_log .. week_entry

            mispredictions_this_week = 0
            total_predictions_this_week = 0
            number_of_weeks = number_of_weeks + 1
        end

        total_predictions_this_week = total_predictions_this_week + 1
        if route ~= predicted_route then
            mispredictions_this_week = mispredictions_this_week + 1
        end

        previous_day_of_week = date.day_of_week
    end

    return reporter
end


-------------------------------------------------------------------------
--  Humans
--
--  Each human is a table with a description and a lookAtTripPlanner
--  function.  Every call to lookAtTripPlanner resumes the human's
--  coroutine and returns the next (date, route) pair that the human
--  checks.

-- Yields a single planner check at a normally distributed hour.
local function checkTripOnce (day, route, base_hour, deviation)
    local hour = gaussianRandom(base_hour, deviation)
    coroutine.yield(makeDate(day, hour), route)
end

local function isWeekday (day)
    local day_of_week = dayOfWeekFromDay(day)
    return day_of_week ~= "Saturday" and day_of_week ~= "Sunday"
end

-- For every weekday in first_day..last_day, checks the morning route
-- around 9am and the evening route around 5pm using `checkTrip`.
local function commuteOnWeekdays (checkTrip, first_day, last_day,
                                  morning_route, evening_route)
    for day = first_day, last_day do
        if isWeekday(day) then
            checkTrip(day, morning_route, 9, 1/2)
            checkTrip(day, evening_route, 12 + 5, 1/2)
        end
    end
end

-- Endless schedule shared by the Monday/Thursday humans: Oakland on
-- Mondays, an extra Fremont round trip on Thursdays, SF otherwise.
-- `fremont_departure_hour` is the mean hour of the "SF To Fremont" check.
local function followMondayThursdaySchedule (checkTrip, fremont_departure_hour)
    for day = 0, math.huge do
        local day_of_week = dayOfWeekFromDay(day)
        if day_of_week ~= "Saturday" and day_of_week ~= "Sunday" then
            if day_of_week == "Monday" then
                checkTrip(day, "Berkeley to Oakland", 9, 1/2)
                checkTrip(day, "Oakland to Berkeley", 12 + 5, 1/2)
            elseif day_of_week == "Thursday" then
                checkTrip(day, "Berkeley to SF", 9, 1/2)
                checkTrip(day, "SF To Fremont", fremont_departure_hour, 1/4)
                checkTrip(day, "Fremont to SF", 12 + 3.5, 1/4)
                checkTrip(day, "SF to Berkeley", 12 + 5, 1/2)
            else
                checkTrip(day, "Berkeley to SF", 9, 1/2)
                checkTrip(day, "SF to Berkeley", 12 + 5, 1/2)
            end
        end
    end
end


simple_human = {
    description = "Every weekday, this person goes to SF at 9am and returns to Berkeley at 5pm."
}

simple_human.lookAtTripPlanner = coroutine.wrap(function ()
    commuteOnWeekdays(checkTripOnce, 0, math.huge, "Berkeley to SF", "SF to Berkeley")
end)


human_who_switches = {
    description = "Every weekday, this person goes to SF at 9am and returns to Berkeley at 5pm. " ..
                  "After 70 days, She then switches to going to/from San Leandro instead of SF."
}

human_who_switches.lookAtTripPlanner = coroutine.wrap(function ()
    commuteOnWeekdays(checkTripOnce, 0, 69, "Berkeley to SF", "SF to Berkeley")
    commuteOnWeekdays(checkTripOnce, 70, math.huge,
                      "Berkeley To San Leandro", "San Leandro to Berkeley")
end)


human_with_monday_and_thursday = {
    description = "Most weekdays, this person goes to SF at 9am and returns to Berkeley at 5pm. " ..
                  "On Mondays, she goes to Oakland instead of Berkeley. " ..
                  "On Thursdays, she makes an additional trip to Fremont at 2 pm, returning at 3:30 pm."
}

human_with_monday_and_thursday.lookAtTripPlanner = coroutine.wrap(function ()
    followMondayThursdaySchedule(checkTripOnce, 12 + 2)
end)


human_with_multicheck = {
    description = "This is the same as the Monday/Thursday person, except she checks the trip planner "..
                  "six times for each trip, instead of once."
}

human_with_multicheck.lookAtTripPlanner = coroutine.wrap(function ()
    -- Six checks per trip, yielded in chronological order.
    local function checkTrip (day, route, base_hour, deviation)
        local hours = {}
        for i = 1, 6 do
            hours[i] = gaussianRandom(base_hour, deviation)
        end
        table.sort(hours)
        for i = 1, 6 do
            coroutine.yield(makeDate(day, hours[i]), route)
        end
    end

    -- NOTE(review): this human leaves for Fremont at 2:30pm whereas the
    -- single-check Monday/Thursday human leaves at 2pm -- confirm that
    -- the difference is intended.
    followMondayThursdaySchedule(checkTrip, 12 + 2.5)
end)


human_with_noise = {
    random_checks_per_real_check = 0.5,
    description = "Every weekday, this person goes to SF at 9am and returns to Berkeley at 5pm, " ..
                  "but also does some amount of random checking around that time as well."
}

human_with_noise.lookAtTripPlanner = coroutine.wrap(function ()
    -- Yields one real check for `route` plus, on average,
    -- random_checks_per_real_check noise checks ("noise <n>"), all in
    -- chronological order.  The setting is read on every call, so it may
    -- be changed between simulations.
    local function checkTrip (day, route, base_hour, deviation)
        local number_of_stations = 20

        -- The integer part of the setting is a guaranteed number of extra
        -- checks; the fractional part is the chance of one more.
        local int_random, frac_random =
            math.modf(human_with_noise.random_checks_per_real_check)
        local number_of_checks_now =
            1 + int_random + (math.random() < frac_random and 1 or 0)

        local hours = {}
        for i = 1, number_of_checks_now do
            hours[i] = gaussianRandom(base_hour, deviation)
        end
        table.sort(hours)

        -- Exactly one of the checks is the real one.
        local planned_check_number = math.random(number_of_checks_now)
        for i, hour in ipairs(hours) do
            local check_route
            if i == planned_check_number then
                check_route = route
            else
                check_route = "noise " .. math.random(number_of_stations)
            end
            coroutine.yield(makeDate(day, hour), check_route)
        end
    end

    commuteOnWeekdays(checkTrip, 0, math.huge, "Berkeley to SF", "SF to Berkeley")
end)


-------------------------------------------------------------------------
--  Util

-- Builds a date table for (day, hour): day is a day count starting at 0
-- (a Monday), hour is a possibly fractional hour of that day.
-- starting_hour and ending_hour both start out equal to hour.
-- tostring(date) gives e.g. "Monday (0) 9:30".
function makeDate (day, hour)
    local day_of_week = dayOfWeekFromDay(day)

    local date = {
        day = day,
        day_of_week = day_of_week,
        hour = hour,
        starting_hour = hour,
        ending_hour = hour,
        is_weekend = (day_of_week == "Saturday" or day_of_week == "Sunday")
    }

    setmetatable(date, {
        __tostring = function ()
            -- Floor before formatting: "%d" raises an error when given a
            -- fractional number on Lua 5.3 and later.
            local whole_hours = math.floor(hour)
            local minutes = math.floor(60 * (hour - whole_hours))
            local hour_string = string.format("%d:%02d", whole_hours, minutes)
            return string.format("%s (%d) %s", day_of_week, day, hour_string)
        end
    })

    return date
end

local days_of_week = {
    "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"
}

-- Maps a day count to its weekday name; day 0 is a Monday.
function dayOfWeekFromDay (day)
    return days_of_week[1 + day % 7]
end

-- Returns a normally distributed random number with the given mean and
-- standard deviation, using the polar form of the Box-Muller transform.
function gaussianRandom (mean, deviation)
    local x1, x2, w
    repeat
        x1 = 2 * math.random() - 1
        x2 = 2 * math.random() - 1
        w = x1 * x1 + x2 * x2
    until w > 0 and w < 1    -- w == 0 would make the result NaN

    local unit_normal = x1 * math.sqrt(-2 * math.log(w) / w)
    return mean + deviation * unit_normal
end

-- Returns the key of `t` whose value is largest, or nil when `t` is
-- empty.  Ties are broken by pairs() order, which is unspecified.
function getKeyWithMaxValue (t)
    local best_v = -math.huge
    local best_k
    for k, v in pairs(t) do
        if v > best_v then
            best_v = v
            best_k = k
        end
    end
    return best_k
end

-- Breaks `s` into lines of at most `width` characters (default 70) at
-- whitespace, joined by "\n".  Each word keeps its trailing whitespace; a
-- word longer than `width` gets a line to itself.
function wrapString (s, width)
    width = width or 70

    local lines = {}
    local line = ""
    for word in string.gmatch(s, "%S+%s*") do
        -- Never break before the first word of a line; that would only
        -- emit an empty line.
        if line ~= "" and string.len(line) + string.len(word) > width then
            lines[#lines + 1] = line
            line = ""
        end
        line = line .. word
    end
    lines[#lines + 1] = line

    return table.concat(lines, "\n")
end


-------------------------------------------------------------------------
--  Go

main()