Jekyll2023-10-02T21:39:12+00:00https://svenschmit.com/feed.xmlRandom ProjectionsPersonal blog Sven SchmitSven SchmitCODE @ MIT 2022 round up2022-10-25T07:00:00+00:002022-10-25T07:00:00+00:00https://svenschmit.com/code-roundup<p>After two long and rewarding days, CODE 2022 is a wrap.
Here’s a biased look at some of my personal highlights;
keep in mind that this is far from a complete list: the conference was filled with quality content, and due to the parallel nature of the talks, I missed more than half of them to begin with.</p>
<p><em>Note: the plenary talks should be posted on Youtube in the next few weeks, at which point I will update this post with links.
The figures used in this post are taken from the published work by the respective authors.</em></p>
<h1 id="jake-hoffman-on-the-effects-of-ux-on-understanding-of-uncertainty">Jake Hofman on the effects of UX on understanding of uncertainty</h1>
<p>Jake Hofman kicked off the conference with an important and often undervalued aspect of experimentation: the user interface really matters.
Starting from the “inference” versus “prediction” views popularized by <a href="https://projecteuclid.org/journals/statistical-science/volume-16/issue-3/Statistical-Modeling--The-Two-Cultures-with-comments-and-a/10.1214/ss/1009213726.full">Breiman</a>,
he showed that even experts’ interpretation of the effectiveness of a drug in an experiment depends a lot on whether one uses the standard error of the mean (which understates variability and hence leans effective)
or the standard deviation of the sample (leans ineffective).</p>
<p>It is an important reminder that while we tend to nerd out on the mathematical problems,
we cannot forget about the user experience; we need to make sure results are accessible
and empower everyone to make good decisions, not only those with a PhD in Statistics.</p>
<h1 id="interference">Interference</h1>
<p>It’s clear that interference is one of the most important open issues in experimentation, and more and more people realize how it is affecting their experiments.
For those less acquainted with this problem: interference occurs when treatment of one subject in the experiment affects outcomes of other subjects, for example due to a network structure (video chat is only valuable if your friends also have video chat) or a two-sided market (if you book a room, then it is no longer available for someone else).
The good news is that there is a lot of exciting research happening in this area.</p>
<h2 id="ilan-lobel-on-leveraging-shadow-prices">Ilan Lobel on leveraging shadow prices</h2>
<p><img src="assets/img/code/shadow.png" alt="shadow prices" /></p>
<p>Interference can take many forms, and is often difficult to model.
However, Ilan Lobel argues that the problem can be cleanly modeled in matching markets where a platform has strong control over the matching process (think Uber/Lyft).
Namely, the shadow prices of the optimization problem can be used to correct for the bias introduced.
A beautiful linearization argument shows both the issue with naive estimation of the average treatment effect (ATE) and how shadow prices address it.</p>
<p>At Stitch Fix I worked on both improving matches between clients and inventory using shadow prices, and on correcting for <a href="/virtual-warehouse/">interference</a> in the experiments we ran, but we never realized the connection.</p>
<p>The full paper is on <a href="https://arxiv.org/abs/2205.02274">Arxiv</a>.</p>
<h1 id="kris-ferreira-on-human-algorithm-collaboration">Kris Ferreira on human algorithm collaboration</h1>
<p>In a fantastic plenary talk, Kris Ferreira discussed her quest to understand how decision makers interact with algorithms to make better decisions, and what pitfalls they fall into.
By decomposing information into a public part, accessible to both the decision maker and the algorithm, and a private part, accessible only to the decision maker, she shows that decision makers struggle to understand when the algorithm provides strong predictions versus when they should rely on important private information to overrule the algorithm’s suggestions.</p>
<p>Instead, decision makers tend to use a convex combination of their own prediction and the algorithm’s suggestion, independent of the quality of the algorithm’s prediction on a particular task.
This leads to suboptimal decisions both when the algorithm is correct and when it is wrong.</p>
<p>There is clearly growing interest in understanding human-algorithm interaction, and I am curious to see how this area develops in the coming years.</p>
<h1 id="kelly-pisane-on-unbiased-impact-estimation">Kelly Pisane on unbiased impact estimation</h1>
<p>High on my list of best talks was Kelly Pisane’s from Booking.com.
She discussed how to obtain unbiased impact estimates from statistically significant experiment results.
The talk stands out because it demonstrates clearly that sometimes simple, practical suggestions trump complex mathematical modeling.</p>
<p>Every experiment platform faces a post-selection bias problem:
statistically significant effect estimates have an upward bias, but it is difficult to understand the magnitude of this bias.
Rather than trying to correct for this mathematically, we can do the following:
after making a decision, split the users who were in the control group into a new treatment/control split, and use data from that second “experiment” to get unbiased estimates of the treatment effect.
We avoid the issue that users may already have been exposed to the winning treatment.
Furthermore, it often takes some time for engineers to roll out the winner anyway, so you might as well run this second, impact, experiment.
Even if you are underpowered to understand the impact of a single experiment, the data can still be used in aggregate across experiments to understand the true impact of experiments.
Clever indeed!</p>
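<p>A small simulation (with hypothetical effect sizes, not numbers from the talk) makes both the bias and the fix concrete: conditioning on significance inflates the naive estimates, while estimates from an independent follow-up experiment center on the true effect.</p>

```python
import numpy as np

rng = np.random.default_rng(7)
true_effect, sigma, n = 0.1, 1.0, 400       # hypothetical per-experiment numbers
se = sigma * np.sqrt(2 / n)                 # std. error of a difference in means

naive, holdout = [], []
for _ in range(2000):
    est = true_effect + rng.normal(0, se)   # original experiment's estimate
    if est / se > 1.96:                     # keep only significant winners
        naive.append(est)
        # independent follow-up "impact" experiment during rollout: unbiased
        holdout.append(true_effect + rng.normal(0, se))

print(np.mean(naive), np.mean(holdout))     # naive is inflated; holdout is not
```

<p>With these numbers the selected naive estimates average well above the true effect of 0.1, while the follow-up estimates do not, and pooling follow-ups across many experiments tightens the aggregate estimate further.</p>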
<h1 id="sequential-testing">Sequential testing</h1>
<p>There is clearly a lot of momentum in both academia and industry on sequential testing, with an entire session dedicated to it at the conference.
If you are less familiar with sequential testing:
unlike classical frequentist methods, this methodology allows for continuous monitoring and adaptive decision making with strong statistical guarantees.</p>
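<p>To give a flavor of these methods, here is a sketch of the standard mixture sequential probability ratio test (mSPRT) for a normal mean with known variance, not any particular platform’s implementation: the statistic can be monitored after every observation, and rejecting once it exceeds 1/α controls the type I error at α under any stopping rule.</p>

```python
import math

def msprt_rejects(x, sigma=1.0, tau=1.0, alpha=0.05):
    """Mixture SPRT for H0: mu = 0 with N(mu, sigma^2) data and a
    N(0, tau^2) mixing distribution over the alternative.
    Returns the first n at which H0 is rejected, or None."""
    s = 0.0
    for n, xi in enumerate(x, start=1):
        s += xi
        xbar = s / n
        v = sigma**2 + n * tau**2
        log_lam = 0.5 * math.log(sigma**2 / v) \
                  + (n**2 * xbar**2 * tau**2) / (2 * sigma**2 * v)
        if log_lam >= math.log(1 / alpha):   # valid at any stopping time
            return n
    return None

print(msprt_rejects([1.0] * 50))   # strong effect: rejects early
print(msprt_rejects([0.0] * 50))   # null data: never rejects
```
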
<p>Netflix is heavily invested in sequential testing, with multiple talks, and Microsoft and Adobe are adopting the sequential paradigm as well.
That said, the sequential framework has pros and cons and does not fit every use case.
It will be very interesting to see how this space develops in the coming years.</p>
<h2 id="michael-lindon-on-sequential-testing-procedures-for-regression">Michael Lindon on sequential testing procedures for regression</h2>
<p><img src="assets/img/code/sequential.png" alt="sequential confidence" /></p>
<p>It is hard to pick a particular talk from a session full of great content,
but personally I really enjoyed Michael’s talk on extending sequential procedures for regressions;
this enables us to combine sequential analysis with pre-experiment covariates and linear regression to reduce variance and thus speed up experiments.</p>
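<p>Netflix’s regression machinery is their own, but the variance-reduction idea can be sketched with a CUPED-style adjustment on simulated data (the covariate and coefficients below are illustrative): subtract from the outcome the component explained by a pre-experiment covariate, leaving the mean untouched while shrinking the variance.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000
x = rng.normal(0.0, 1.0, n)              # pre-experiment covariate
y = 2.0 * x + rng.normal(0.0, 1.0, n)    # in-experiment outcome

# CUPED: y_adj = y - theta * (x - mean(x)) with theta = cov(y, x) / var(x);
# the mean is unchanged, but the explained variance is removed
theta = np.cov(y, x)[0, 1] / np.var(x)
y_adj = y - theta * (x - x.mean())

print(y.var(), y_adj.var())              # adjusted variance is far smaller
```

<p>Smaller variance means smaller standard errors for the treatment effect, which is exactly why covariate adjustment speeds up experiments.</p>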
<p>Two aspects of this talk stood out to me in particular.</p>
<ol>
<li>It is often easy for research to lean theoretical, but here the practical use case is extremely obvious:
adjusting for pre-experiment data often speeds up experiments substantially.</li>
<li>The approach by Netflix is similar but also quite distinct from the way we at Eppo model this problem.</li>
</ol>
<p>The latter is indicative of a broader phenomenon that stood out at the conference:
we all have different approaches to the same problems, and that makes it particularly valuable to come together and be inspired by each other’s work.</p>
<p>Finally, worth mentioning is the call out by James McQueen that in practice it is very rare to observe a single outcome per user.
Instead, we observe a sequence of events from which we construct an outcome that evolves over the duration of an experiment,
but this often gets ignored in theoretical work on experimentation.
Having struggled with this myself, it is great to see this problem being discussed more broadly.</p>
<p>The full paper is on <a href="https://arxiv.org/abs/2210.08589">Arxiv</a>.</p>
<h1 id="hamsa-bastani-on-covid-prevention-in-greece-using-bandits">Hamsa Bastani on COVID prevention in Greece using Bandits</h1>
<p><img src="assets/img/code/eva.png" alt="bandit" /></p>
<p>When it comes to bandits, everyone seems to agree on two things:</p>
<ol>
<li>They are beautiful to study from a theoretical standpoint</li>
<li>But in practice they are not used all that much due to a variety of complications</li>
</ol>
<p>Hamsa demonstrated that bandits can be powerful in practice, even under challenging circumstances.
Picture the setting: it is Spring 2020, the pandemic has just started, and in Greece, the tourist season is approaching rapidly.
Since tourism accounts for 25% of GDP, Greece literally cannot afford to remain in lockdown and decides to open borders for tourists.
How do you keep the public safe when you expect 30k to 100k tourists to arrive every day, while you only have capacity for 8k daily PCR tests with a 2-day turnaround time?
Furthermore, the rate of COVID infections flares up in different countries across time so the testing results are only useful for a short period of time before becoming outdated.
How can you allocate the limited number of tests in a way that minimizes the number of infected tourists entering the country? Hamsa and team’s answer: use a contextual bandit.</p>
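<p>The team’s actual system (Eva) combines empirical-Bayes estimation with an optimization layer over testing budgets; the details are in the paper. As a heavily simplified flavor of the core idea, allocating scarce tests toward passenger types with high estimated prevalence while still exploring, here is a Thompson-sampling sketch with made-up numbers:</p>

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical passenger "types" (e.g. origin country) with unknown prevalence
prevalence = {"A": 0.002, "B": 0.01, "C": 0.03}
a = {t: 1.0 for t in prevalence}   # Beta(1, 1) prior per type
b = {t: 1.0 for t in prevalence}

caught = 0
n_tests = 20_000                   # made-up total testing budget
for _ in range(n_tests):
    # Thompson sampling: test the type whose sampled prevalence is highest
    t = max(prevalence, key=lambda k: rng.beta(a[k], b[k]))
    positive = rng.random() < prevalence[t]
    a[t] += positive
    b[t] += 1 - positive
    caught += positive

random_baseline = n_tests * np.mean(list(prevalence.values()))
print(caught, random_baseline)     # the bandit catches more than random would
```
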
<p>Lo and behold, it worked: the team caught almost twice as many infections as random sampling would have.
Hamsa demonstrates there is hope for the practical applicability of bandits,
and likely saved a fair number of lives in the process.</p>
<p>The full paper is published in <a href="https://www.nature.com/articles/s41586-021-04014-z">Nature</a>.</p>
<h1 id="john-cai-on-heterogeneous-treatment-effect-estimation">John Cai on heterogeneous treatment effect estimation</h1>
<p>Another big topic at CODE was understanding and estimating heterogeneous treatment effects (HTE), with clear applications in industry illustrated by talks from Snap, Netflix, and Meta.
Automatically finding segments of users for whom a treatment works particularly well, or poorly, has clear applications but is also a thorny problem that is far from solved.</p>
<p>John Cai from Snap gave an overview of Snap’s approach to HTE, bridging the gap between theory and practice:
first by evaluating the effectiveness of testing whether HTE exist in the first place.
If there are no heterogeneous effects, then a reasonable assumption is that the variance of control and treatment are the same.
This reduces the problem to testing the hypothesis that the variances are equal.</p>
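<p>As a sketch of what such a variance test can look like (the specific procedure in the talk may differ), simulated per-user treatment effects inflate the treatment group’s variance, which a test such as Levene’s picks up:</p>

```python
import numpy as np
from scipy.stats import levene

rng = np.random.default_rng(1)
control = rng.normal(0.0, 1.0, 5000)
# heterogeneous effects: each treated user gets their own effect draw,
# which inflates the variance of the treatment group
treatment = rng.normal(0.0, 1.0, 5000) + rng.normal(0.2, 0.5, 5000)

stat, pvalue = levene(control, treatment)
print(pvalue)   # small p-value: variances differ, evidence of heterogeneity
```
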
<p>If we do detect a difference in variances, then we can focus on finding dimensions that contribute most to heterogeneity in treatment effects by
decomposing the total treatment effect variation into explained and idiosyncratic variation.
All of this was demonstrated using an actual experiment the team ran at Snap.</p>
<h1 id="wrapping-up">Wrapping up</h1>
<p>It is clear that interest in experimentation is growing rapidly,
and we are only scratching the surface on problems practitioners face.
In particular, I expect there to be a lot more work on interference, sequential testing, heterogeneous treatment effects, and understanding long term outcomes.
But as we nerd out on these problems, we should also keep the UX in mind:
we cannot just solve the mathematical problems, we also need to make solutions accessible to the non-expert so they feel empowered to make better decisions.</p>
<p>If you also attended CODE, let me know what your highlights are.
If you did not, hopefully this inspires you to visit Boston next year; the weather was surprisingly great!</p>Sven SchmitAfter two long and rewarding days, CODE 2022 is a wrap. Here’s a biased look at some of my personal highlights; Keep in mind that this is far from a complete list: the conference was filled with quality content and due to the parallel nature of the talks, I missed more than half of them to begin with.Beware of the Bayesian imposter2022-09-10T07:00:00+00:002022-09-10T07:00:00+00:00https://svenschmit.com/bayesian-paradox<p>Sometimes, statistical guarantees are not what they seem.
Here, we discuss the implications of a classic work that demonstrates a paradox with the Bayesian approach to experiment analysis:
when one is not careful, the experimenter risks running a frequentist analysis without realizing it.
This can have important implications: when combined with peeking, the credible intervals might not be so credible after all.</p>
<h2 id="the-great-debate">The Great Debate</h2>
<p>When you are building an <a href="https://www.geteppo.com/">experimentation platform</a>,
it’s impossible to avoid the debates between frequentist and Bayesian approaches to hypothesis testing.
Ardent supporters on both sides espouse the benefits of their approach, but when the debate has been raging for decades,
and between the most famous and brightest statisticians, it’s no surprise there aren’t any easy answers.</p>
<p>An oft highlighted issue of the classical frequentist approach is that of the peeking problem:
with modern tooling, it’s often easy to monitor the results of experiments while collecting data.
However, the frequentist guarantees <a href="https://www.evanmiller.org/how-not-to-run-an-ab-test.html">break when you peek at experiments</a>, leading to an inflated type I error rate:
you end up claiming there is a difference between variants far more often than intended when both variants in fact have the same outcome distribution.
On the other hand, Bayesian supporters often highlight that the Bayesian approach does not suffer from the peeking problem,
claiming that this is one of the benefits of using a Bayesian testing regime over a frequentist one.</p>
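<p>The peeking problem itself is easy to reproduce in a small simulation (illustrative numbers, not taken from any of the linked posts): run an A/A test, compute a z-statistic after every observation, and stop as soon as it crosses 1.96.</p>

```python
import numpy as np

rng = np.random.default_rng(42)
n_sims, n_obs = 2000, 1000
fixed_hits = peek_hits = 0

for _ in range(n_sims):
    # A/A test: per-observation treatment-minus-control differences, no true effect
    diffs = rng.normal(0.0, 1.0, n_obs)
    z = np.cumsum(diffs) / np.sqrt(np.arange(1, n_obs + 1))
    fixed_hits += abs(z[-1]) > 1.96             # analyze once, at the end
    peek_hits += np.any(np.abs(z[20:]) > 1.96)  # peek after every observation

print(fixed_hits / n_sims, peek_hits / n_sims)  # ~0.05 versus far higher
```

<p>The fixed-horizon analysis keeps the nominal 5% false positive rate; peeking after every observation rejects in a large fraction of A/A runs.</p>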
<p>Statistics is often more subtle than it first seems, and so we have to wonder: is this too good to be true?
Famous Bayesian statistician <a href="https://www.statslab.cam.ac.uk/~apd/">Philip Dawid</a> wrote a fantastic paper in the 90s: <a href="https://projecteuclid.org/ebooks/institute-of-mathematical-statistics-lecture-notes-monograph-series/Multivariate-analysis-and-its-applications/Chapter/Selection-paradoxes-of-Bayesian-inference/10.1214/lnms/1215463797">“Selection paradoxes of Bayesian inference”</a>, shared with me by <a href="https://www.stevehoward.org/">Steve Howard</a>,
that I highly recommend you read.
It’s fairly accessible, both in that there is no paywall, and that the reasoning is easy to follow.
In this work, he introduces what I would like to call the “Bayesian imposter”: someone who claims the benefits of the Bayesian approach,
while actually doing frequentist analysis.</p>
<p>Let’s put this in an experimentation context.
Before we do however, I want to note that this post is not about choosing sides: both frequentist and Bayesian approaches have their merits, and shine in different ways.
Rather, the work of Dawid helps us to understand and appreciate the subtleties in this debate.</p>
<h2 id="the-experimentation-setting">The experimentation setting</h2>
<p>First of all, let’s get on the same page regarding what an experiment is; we will keep it high level and consider the most basic setting; it is all we need.
Consider we want to test a new variant: treatment, versus the status quo: control, and we would like to figure out whether treatment is better than control according to some metric we care about.
We resort to some statistics to help us figure this out: we randomly assign subjects to either treatment or control, and measure the outcome on said metric.
Now we use some statistical method to compare the two variants, e.g. by comparing the means using a T-test (frequentist) or comparing the posteriors by combining priors with a likelihood (Bayesian).</p>
<h2 id="frequentist-versus-bayesian-hypothesis-testing">Frequentist versus Bayesian hypothesis testing</h2>
<p>Before diving in, let us briefly recap the differences between frequentist and Bayesian hypothesis testing.</p>
<p>In the frequentist setting, the underlying parameter is fixed, while the observed data is considered the random variable.
Thus, the frequentist thinks about guarantees on their inference methods through considering repeated draws of data.
For example, a confidence interval is a function of the data (and thus also a random variable),
with the underlying idea that, as we generate data, and hence confidence intervals, over and over, most of the
generated confidence intervals contain the parameter of interest.
But this has the unintuitive downside that a frequentist cannot give any certainty about a particular set of observations.</p>
<p>The Bayesian perspective flips this around: you are handed a set of observations, clearly those are fixed.
What you do not know is the value of the underlying parameter, so let’s consider that the random variable.
We encode our beliefs through a prior distribution for said parameter, and then combine that with the likelihood of
observing the data given the parameter to form a posterior distribution.
Now the posterior distribution allows us to reason about this particular outcome: conditioning on the data is not a problem,
because inference on the parameter is already conditioned on the data!</p>
<p>The last point is often quoted as one of the main benefits of the Bayesian approach to experimentation.
With the classical frequentist approach, the sample size needs to be fixed up front (we cannot adapt to the data).
If we ignore this and decide whether to continue collecting data based on what we have observed so far, or are ready to make a decision, we can dramatically
increase the risk of a false positive.
This is different in the Bayesian case: we can be adaptive and make a decision at any point:
do we keep gathering more data, choose the treatment variant, or choose the control variant?
This follows from your elementary probability class: from the Bayesian perspective, the data is fixed (a constant) and only the
parameter we are interested in is considered a random variable. Conditioning on the data (a constant) does not impact results.</p>
<h2 id="dawids-paradox">Dawid’s paradox</h2>
<p>But is it really so black and white?
Dawid, a well known Bayesian statistician, posed the following paradox in the paper mentioned above:
The Bayesian approach sounds great: we can be adaptive to the data, while the frequentist cannot.
But now consider a Bayesian who uses a non-informative (or flat) prior: a prior under which every outcome is equally likely.
It is easy to see that when using such a non-informative prior, the frequentist and Bayesian approach lead to the exact same results:
the Bayesian approach multiplies the likelihood with the prior to get the posterior (modulo normalization).
If the prior is flat, then we are multiplying the likelihood by 1.
Hence, our inference is based only on the likelihood – exactly the same likelihood the frequentist is using!</p>
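<p>In the conjugate normal model this coincidence is explicit. The sketch below is my own worked example, not from the paper: with a N(0, τ²) prior on the mean and known noise variance, letting τ² grow makes the posterior mean and variance converge to the sample mean and σ²/n, exactly the frequentist point estimate and squared standard error.</p>

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(0.3, 1.0, 50)     # data with unknown mean, known sigma = 1
n, sigma2 = len(x), 1.0

def posterior(tau2):
    """Normal-normal conjugacy: N(0, tau2) prior on the mean."""
    post_var = 1.0 / (1.0 / tau2 + n / sigma2)
    post_mean = post_var * x.sum() / sigma2
    return post_mean, post_var

# near-flat prior: the posterior collapses onto the frequentist answer
mean_flat, var_flat = posterior(tau2=1e8)
print(mean_flat - x.mean(), var_flat - sigma2 / n)   # both essentially zero
```

<p>With a flat enough prior, the 95% credible interval and the 95% confidence interval are numerically the same object, which is precisely Dawid’s point.</p>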
<p>But now the adaptivity is a bit odd: how can it be fine to be adaptive (or peek) when using the Bayesian approach,
while the frequentist is stuck waiting to collect all their data, despite the two approaches leading to exactly the same results?</p>
<h2 id="resolving-the-paradox">Resolving the paradox</h2>
<p>The culprit is clear: it is the flat prior that causes the Bayesian posterior to be the same as the likelihood on which
frequentist analysis is based.
As usual, there is no free lunch: the Bayesian approach has not magically solved the issue of peeking for free.
Rather, to solve the problem posed by peeking, we need to solve the problem of the prior: Bayesian analysis works when the prior works.
David Robinson demonstrates the <a href="http://varianceexplained.org/r/bayesian-ab-testing/">same point using simulations</a>.</p>
<p>Adaptivity, or the ability to peek, is very powerful and the Bayesian approach is far simpler than the <a href="https://arxiv.org/abs/1810.08240">frequentist alternative of sequential testing</a>.
But powerful results often require strong assumptions; in this case the need for a well-chosen prior.</p>
<p>In particular, when we are unsure what prior to choose, the suggestion to use a weak prior is much more common than to use a strong one, and this makes sense:
in the small data regime, a strong prior can dominate the posterior, which is undesirable especially when we do not feel strongly about the strong prior.
But the discussion above reveals that it is a balancing act: both weak and strong priors have downsides and hence carefully finding the correct balance is unavoidable.</p>
<h2 id="take-away">Take-away</h2>
<p>Bayesian and frequentist approaches and guarantees are fundamentally different, but under certain conditions, they lead to exactly the same results.
Dawid demonstrates clearly that this creates a paradox: when results are the same, approaches cannot have fundamentally different guarantees.
This is where the Bayesian imposter comes in: to truly enjoy the benefits of a Bayesian approach, one needs to carefully think about what prior to use,
otherwise, the experimenter might just be a frequentist in disguise, ignorant of the risks that poses to the validity of their statistical claims.</p>
<p>More generally, when the experts have been debating the respective merits of frequentist versus Bayesian approaches,
it must be true that one cannot be clearly better than the other: otherwise, the debate would have been settled a long time ago.
When you experiment as a Bayesian, make sure you think carefully about your prior, or you might well be a frequentist without realizing it!</p>Sven SchmitSometimes, statistical guarantees are not what they seem. Here, we discuss the implications of a classic work that demonstrates a paradox with the Bayesian approach to experiment analysis: when not careful, the experimenter runs the risk of running a frequentist analysis without realizing it. This can have important implications: when combined with peeking, the credible intervals might not be so credible after all.Language modeling with Jax and RNNs2021-06-20T16:00:00+00:002021-06-20T16:00:00+00:00https://svenschmit.com/jax-language-model-rnn<p><a href="https://github.com/google/jax">Jax</a> is a relatively new Python library aimed as a drop in replacement for Numpy for machine learning research.
It sets itself apart due to its functional approach, which I find really enjoyable.
Recently I have been playing around with implementing a simple RNN using Flax to get beyond the basics of Jax, but without adding all the bells and whistles.</p>
<p>The goal is to create a character level language model using a simple RNN following the approach by Andrej Karpathy’s 2015 blog post <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">“The Unreasonable Effectiveness of Recurrent Neural Networks”</a>, but focus on Jax rather than creating an unreasonably effective network.
That is, given a reasonably sized text, say a book, we want to create a model that generates content one character at a time. We will find that it is relatively easy for the simplest of RNNs to learn how to string together a few words using a small network and few training epochs.</p>
<p>We will assume a basic familiarity with Jax. You can find the code on <a href="https://colab.research.google.com/drive/1Qw7zilRZVnVE4PMuKEGHrKBGaaSB3NWG">colab</a>, and I am by no means a Jax expert, so feel free to reach out with comments and suggestions.</p>
<h1 id="flax">Flax</h1>
<p>There are a few libraries built on top of Jax that help with creating neural networks specifically: Haiku by DeepMind, and Flax and Trax, both developed by Google AI.
Without any particularly strong preference for any of them, we will use Flax for this task.</p>
<p>We use the following imports</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.ops</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">onp</span> <span class="c1"># convention: original numpy
</span>
<span class="kn">import</span> <span class="nn">flax</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">linen</span> <span class="k">as</span> <span class="n">nn</span>
<span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">optim</span>
</code></pre></div></div>
<h1 id="preparing-data">Preparing data</h1>
<p>When prototyping any machine learning method, it is always helpful to have a trivial example data set. This is particularly beneficial when creating a neural network, as it is very easy to miss subtle bugs that cause unexpected behavior. Overfitting a model to a small dataset is a great sanity check to ensure your model is actually doing what you think it is. In this case, let’s use the repetitive string <code class="language-plaintext highlighter-rouge">abcd...abcd...</code> as input.</p>
<p>To map characters to indices and back, we create a convenience function which, for lack of a better name, I call a bridge:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">id_bridge</span><span class="p">(</span><span class="n">iterable</span><span class="p">):</span>
<span class="s">""" provides mapping to and from ids """</span>
<span class="k">return</span> <span class="p">({</span><span class="n">elem</span><span class="p">:</span> <span class="nb">id</span> <span class="k">for</span> <span class="nb">id</span><span class="p">,</span> <span class="n">elem</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">iterable</span><span class="p">)},</span>
<span class="p">{</span><span class="nb">id</span><span class="p">:</span> <span class="n">elem</span> <span class="k">for</span> <span class="nb">id</span><span class="p">,</span> <span class="n">elem</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">iterable</span><span class="p">)})</span>
</code></pre></div></div>
<p>Next, we create some more functions that are helpful for mapping from characters to model input, and from model output back to characters.
Note that <code class="language-plaintext highlighter-rouge">jax.numpy</code> is for the most part a drop-in replacement for Numpy, but its functional nature does not allow us to update arrays in place (e.g. <code class="language-plaintext highlighter-rouge">A[0, 1] = 3</code>). Instead, we can use <code class="language-plaintext highlighter-rouge">jax.ops.index_update</code> to set values at particular indices.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">one_hot</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="p">):</span>
<span class="s">"""
create vector of size n with 1 at index i
"""</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">n</span><span class="p">)</span>
<span class="k">return</span> <span class="n">jax</span><span class="p">.</span><span class="n">ops</span><span class="p">.</span><span class="n">index_update</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="n">char</span><span class="p">):</span>
<span class="k">return</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">char_to_id</span><span class="p">[</span><span class="n">char</span><span class="p">],</span> <span class="nb">len</span><span class="p">(</span><span class="n">char_to_id</span><span class="p">))</span>
<span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">id_to_char</span><span class="p">):</span>
<span class="c1"># for simplicity, pick the most likely character
</span> <span class="c1"># this can be replaced by sampling weighted
</span> <span class="c1"># by the probability of each character
</span> <span class="k">return</span> <span class="n">id_to_char</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">predictions</span><span class="p">))]</span>
</code></pre></div></div>
<p>The decode function takes the model output: predicted probabilities for each character that it will be the next character in the sequence, and returns the character the model finds most likely.</p>
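<p>As the comment in <code class="language-plaintext highlighter-rouge">decode</code> hints, the argmax can be replaced by sampling weighted by the predicted probabilities, which produces more varied generations. A sketch in plain numpy (the post’s code uses Jax, where <code class="language-plaintext highlighter-rouge">jax.random.categorical</code> plays a similar role); the temperature parameter is my own addition:</p>

```python
import numpy as np

def sample_decode(predictions, id_to_char, temperature=1.0, rng=None):
    """Sample the next character with probability proportional to
    predictions ** (1 / temperature); small temperatures approach argmax.
    Assumes predictions are strictly positive (e.g. softmax output)."""
    if rng is None:
        rng = np.random.default_rng()
    logits = np.log(np.asarray(predictions)) / temperature
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    probs /= probs.sum()
    return id_to_char[int(rng.choice(len(probs), p=probs))]
```

<p>At temperature 1.0 this samples directly from the model’s distribution; as the temperature shrinks toward zero, it behaves like the original argmax decode.</p>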
<h1 id="a-simple-recurrent-model">A simple recurrent model</h1>
<p>With that setup out of the way, we can code up a simple recurrent model.
While some recurrent cells are included in Flax, for the sake of learning we implement our own.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">RNNCell</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="o">@</span><span class="n">nn</span><span class="p">.</span><span class="n">compact</span>
<span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="c1"># Wh @ h + Wx @ x + b can be efficiently computed
</span> <span class="c1"># by concatenating the vectors and then having a single dense layer
</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span>
<span class="n">new_state</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">state</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])(</span><span class="n">x</span><span class="p">))</span>
<span class="k">return</span> <span class="n">new_state</span>
</code></pre></div></div>
<p><code class="language-plaintext highlighter-rouge">nn.compact</code> allows us to define the parameters in the forward pass of the model, rather than separately: Flax generates the parameters later based on a sample input.
Otherwise, this code should look rather familiar to similar definitions using other libraries, such as PyTorch.</p>
<p>Next, we stack three RNNCells on top of each other to create our Character RNN.
The <code class="language-plaintext highlighter-rouge">init_state</code> method is not required, but will turn out to be a rather convenient way to initialize the state when we want to generate new sequences.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ChaRNN</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
<span class="n">state_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="n">vocab_size</span><span class="p">:</span> <span class="nb">int</span>
<span class="o">@</span><span class="n">nn</span><span class="p">.</span><span class="n">compact</span>
<span class="k">def</span> <span class="nf">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">i</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">one_hot</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="bp">self</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">)</span>
<span class="n">new_state</span> <span class="o">=</span> <span class="p">[]</span>
<span class="c1"># a rather naive way of stacking multiple RNN cells
</span> <span class="n">new_state_1</span> <span class="o">=</span> <span class="n">RNNCell</span><span class="p">()(</span><span class="n">state</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x</span><span class="p">)</span>
<span class="n">new_state_2</span> <span class="o">=</span> <span class="n">RNNCell</span><span class="p">()(</span><span class="n">state</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">new_state_1</span><span class="p">)</span>
<span class="n">new_state_3</span> <span class="o">=</span> <span class="n">RNNCell</span><span class="p">()(</span><span class="n">state</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">new_state_2</span><span class="p">)</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">vocab_size</span><span class="p">)(</span><span class="n">new_state_3</span><span class="p">))</span>
<span class="k">return</span> <span class="p">[</span><span class="n">new_state_1</span><span class="p">,</span> <span class="n">new_state_2</span><span class="p">,</span> <span class="n">new_state_3</span><span class="p">],</span> <span class="n">predictions</span>
<span class="k">def</span> <span class="nf">init_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="c1"># a convenient way to initialize the state
</span> <span class="k">return</span> <span class="p">[</span>
<span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">state_size</span><span class="p">),</span>
<span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">state_size</span><span class="p">),</span>
<span class="n">np</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">state_size</span><span class="p">)</span>
<span class="p">]</span>
</code></pre></div></div>
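<p>The <code class="language-plaintext highlighter-rouge">one_hot</code> helper used in <code class="language-plaintext highlighter-rouge">__call__</code> above is not defined in this snippet; a minimal version could look like the following (shown in plain NumPy for clarity — since the post's <code class="language-plaintext highlighter-rouge">np</code> is <code class="language-plaintext highlighter-rouge">jax.numpy</code>, whose arrays are immutable, you would use <code class="language-plaintext highlighter-rouge">jax.nn.one_hot(i, vocab_size)</code> instead):</p>

```python
import numpy as np

def one_hot(i, size):
    # encode index i as a length-`size` one-hot vector
    v = np.zeros(size)
    v[i] = 1.0
    return v

v = one_hot(2, 5)
# v has a single 1.0 at position 2
```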
<p>Let’s write a function that can generate new sequences by sampling from the model.
The following function is simple, but not particularly efficient:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">bridge</span><span class="p">,</span> <span class="n">initial</span><span class="o">=</span><span class="s">''</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
<span class="n">char_to_id</span><span class="p">,</span> <span class="n">id_to_char</span> <span class="o">=</span> <span class="n">bridge</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">init_state</span><span class="p">()</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">initial</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">initial</span><span class="p">)</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="k">for</span> <span class="n">char</span> <span class="ow">in</span> <span class="n">initial</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
<span class="n">state</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">char_to_id</span><span class="p">[</span><span class="n">char</span><span class="p">])</span>
<span class="n">next_char</span> <span class="o">=</span> <span class="n">initial</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_length</span><span class="p">):</span>
<span class="n">state</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">char_to_id</span><span class="p">[</span><span class="n">next_char</span><span class="p">])</span>
<span class="n">next_char</span> <span class="o">=</span> <span class="n">decode</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">id_to_char</span><span class="p">)</span>
<span class="n">output</span> <span class="o">+=</span> <span class="n">next_char</span>
<span class="k">return</span> <span class="n">output</span>
</code></pre></div></div>
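<p>The <code class="language-plaintext highlighter-rouge">decode</code> helper is also not shown in this snippet; given that the sampling is described as deterministic later in the post, a greedy argmax decoder is a reasonable sketch:</p>

```python
import numpy as np

def decode(predictions, id_to_char):
    # greedy decoding: map the most likely prediction back to its character
    return id_to_char[int(np.argmax(predictions))]

# example: most of the mass is on index 1, so decode returns that character
char = decode(np.array([0.1, 0.7, 0.2]), {0: "a", 1: "b", 2: "c"})
# char == "b"
```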
<p>Finally, we initialize a model and test whether it can indeed generate a random sample.
Jax handles <a href="https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers">randomness</a> differently from Numpy in that explicit keys have to be used.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">state_size</span> <span class="o">=</span> <span class="mi">8</span>
<span class="c1"># randomness is handled using explicit keys in Jax
</span><span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">ChaRNN</span><span class="p">(</span><span class="n">state_size</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">char_to_id</span><span class="p">))</span>
<span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">init</span><span class="p">(</span><span class="n">subkey</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="n">init_state</span><span class="p">(),</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Model state size: </span><span class="si">{</span><span class="n">model</span><span class="p">.</span><span class="n">state_size</span><span class="si">}</span><span class="s">, vocab size: </span><span class="si">{</span><span class="n">model</span><span class="p">.</span><span class="n">vocab_size</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="c1"># output: Model state size: 8, vocab size: 5
</span>
<span class="c1"># run a single example through the model to test that it works
</span><span class="n">new_state</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="n">init_state</span><span class="p">(),</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">predictions</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="n">model</span><span class="p">.</span><span class="n">vocab_size</span>
<span class="c1"># calling sample on random model leads to random output
</span><span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="p">(</span><span class="n">char_to_id</span><span class="p">,</span> <span class="n">id_to_char</span><span class="p">),</span> <span class="s">'abc'</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c1"># output: 'abcadbaadbadd'
</span></code></pre></div></div>
<h1 id="training">Training</h1>
<p>Now that we have verified the model code works, it is time to focus on optimizing the parameters.
Recall that we want the model to predict the next character based on the sequence of characters seen so far.
The following function chunks the input into batches, producing a sequence of inputs and a sequence of targets to predict, where
the targets are simply the inputs shifted by one position.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">chunker</span><span class="p">(</span><span class="n">seq</span><span class="p">,</span> <span class="n">size</span><span class="p">):</span>
<span class="s">"""
chunks a sequence into two subsequences:
one for inputs, another for targets, by
shifting the input by 1
"""</span>
<span class="n">n</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">seq</span><span class="p">)</span>
<span class="n">p</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">while</span> <span class="n">p</span> <span class="o">+</span> <span class="mi">1</span> <span class="o"><=</span> <span class="n">n</span><span class="p">:</span>
<span class="c1"># ensure the last chunk is of equal size
</span> <span class="k">yield</span> <span class="n">seq</span><span class="p">[</span><span class="n">p</span><span class="p">:</span><span class="nb">min</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">p</span><span class="o">+</span><span class="n">size</span><span class="p">)],</span> <span class="n">seq</span><span class="p">[(</span><span class="n">p</span><span class="o">+</span><span class="mi">1</span><span class="p">):(</span><span class="n">p</span><span class="o">+</span><span class="n">size</span><span class="o">+</span><span class="mi">1</span><span class="p">)]</span>
<span class="n">p</span> <span class="o">+=</span> <span class="n">size</span>
</code></pre></div></div>
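<p>A quick standalone check of the chunking logic, restated in plain Python (with the loop condition written as <code class="language-plaintext highlighter-rouge">p + 1 < n</code>, which also avoids yielding an empty final pair when the sequence length happens to line up that way):</p>

```python
def chunker(seq, size):
    # split a sequence into (inputs, targets) pairs, where each
    # target sequence is the input sequence shifted by one step
    n = len(seq)
    p = 0
    while p + 1 < n:
        # truncate the last chunk so inputs and targets stay aligned
        yield seq[p:min(n - 1, p + size)], seq[(p + 1):(p + size + 1)]
        p += size

pairs = list(chunker("abcdef", 3))
# each target is the input shifted one step: [('abc', 'bcd'), ('de', 'ef')]
```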
<p>Creating the loss function over a sequence of inputs is where Jax really shines in my opinion: we start by implementing the loss
for a single example, and then use Jax functions to vectorize, differentiate and compile.</p>
<p>Note we have to unroll the RNN to compute gradients, and at some point have to truncate the unrolled RNN.
In our case, we compute gradients over a batch, which will define how far the RNN will unroll.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">cross_entropy_loss</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span>
<span class="c1"># note we compute the loss for a single example.
</span> <span class="c1"># we will use vmap below to vectorize
</span> <span class="k">return</span> <span class="o">-</span><span class="n">np</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="n">label</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">rnn_loss</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
<span class="c1"># use lax.scan to efficiently generate a loop over the inputs
</span> <span class="c1"># this function returns thefinal state, and predictions for every step
</span> <span class="c1"># note: scan input array needs have shape [length, 1]
</span> <span class="n">final_state</span><span class="p">,</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">lax</span><span class="p">.</span><span class="n">scan</span><span class="p">(</span><span class="k">lambda</span> <span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">model</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">),</span>
<span class="n">state</span><span class="p">,</span>
<span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">inputs</span><span class="p">]).</span><span class="n">T</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">cross_entropy_loss</span><span class="p">)(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span><span class="n">targets</span><span class="p">]).</span><span class="n">T</span><span class="p">))</span>
<span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">final_state</span>
<span class="c1"># we want both the loss an gradient, we set has_aux because rnn_loss also return final state
# use static_argnums=1 to indicate that the model is static;
# a different model input will require recomplication
# finally, we jit the function to improve runtime
</span><span class="n">rnn_loss_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">rnn_loss</span><span class="p">,</span> <span class="n">has_aux</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
<span class="n">static_argnums</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
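<p>The <code class="language-plaintext highlighter-rouge">has_aux=True</code> pattern is easy to verify on a toy function: <code class="language-plaintext highlighter-rouge">value_and_grad</code> then returns <code class="language-plaintext highlighter-rouge">((value, aux), grad)</code>, which is exactly the shape unpacked in <code class="language-plaintext highlighter-rouge">batch_step</code> below:</p>

```python
import jax
import jax.numpy as jnp

def f(x):
    # toy function returning (loss, auxiliary output),
    # mirroring the (loss, final_state) signature of rnn_loss
    return jnp.sum(x ** 2), x + 1.0

# with has_aux=True, value_and_grad returns ((value, aux), grad)
(value, aux), grad = jax.value_and_grad(f, has_aux=True)(jnp.array([1.0, 2.0]))
# value = 5.0, grad = [2.0, 4.0], aux = [2.0, 3.0]
```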
<h2 id="optimization">Optimization</h2>
<p>We use <code class="language-plaintext highlighter-rouge">flax.optim</code> to handle the gradient steps.
Let’s define the following functions to deal with a single batch, which becomes trivial, and looping through batches to compute an epoch of updates
Note that in this case the gradients are only</p>
<p>Note for each epoch we start with the initial state, and we propagate states across all batches in an epoch.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">batch_step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span>
<span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">state</span><span class="p">),</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">rnn_loss_grad</span><span class="p">(</span><span class="n">optimizer</span><span class="p">.</span><span class="n">target</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
<span class="n">new_optimizer</span> <span class="o">=</span> <span class="n">optimizer</span><span class="p">.</span><span class="n">apply_gradient</span><span class="p">(</span><span class="n">grad</span><span class="p">)</span>
<span class="k">return</span> <span class="n">new_optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">state</span>
<span class="k">def</span> <span class="nf">epoch_step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span>
<span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">init_state</span><span class="p">()</span>
<span class="n">total_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">chunker</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)):</span>
<span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">batch_step</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span>
<span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span>
<span class="k">return</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">total_loss</span> <span class="o">/</span> <span class="p">(</span><span class="n">n</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
<p>Finally, we initialize the optimizer in the following way:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">optimizer_def</span> <span class="o">=</span> <span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_def</span><span class="p">.</span><span class="n">create</span><span class="p">(</span><span class="n">initial_params</span><span class="p">)</span>
</code></pre></div></div>
<p>When we put this together in a simple training function (see the <a href="https://colab.research.google.com/drive/1Qw7zilRZVnVE4PMuKEGHrKBGaaSB3NWG#scrollTo=bI885lfVmKab">Colab</a> for details) and run it on our sample input <code class="language-plaintext highlighter-rouge">abcd...abcd...</code>, we
see that the model quickly learns how to repeat the pattern:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Training RNN on 'abcd...abcd...'
Vocabulary size: 5
State size: 8
Adam optimizer parameters
learning_rate=0.002
weight_decay=0.000
====================
Epoch: 0 loss: 1.643 time: 1.77
Sample: abcd.bd..b..d..d..d..d..d..d..d..d..d..d..d..d..d..d..
Epoch: 5 loss: 1.528 time: 0.07
Epoch: 10 loss: 1.430 time: 0.06
Epoch: 15 loss: 1.348 time: 0.06
Epoch: 20 loss: 1.278 time: 0.06
Epoch: 25 loss: 1.207 time: 0.06
Epoch: 30 loss: 1.126 time: 0.06
Epoch: 35 loss: 1.041 time: 0.06
Epoch: 40 loss: 0.958 time: 0.06
Sample: abcd....bb............................................
Epoch: 45 loss: 0.882 time: 0.05
Epoch: 50 loss: 0.811 time: 0.06
Epoch: 55 loss: 0.744 time: 0.06
Epoch: 60 loss: 0.680 time: 0.06
Epoch: 65 loss: 0.622 time: 0.06
Epoch: 70 loss: 0.569 time: 0.08
Epoch: 75 loss: 0.522 time: 0.07
Epoch: 80 loss: 0.479 time: 0.06
Sample: abcd...abcd...abcd...abcd...abcd...abcd...abcd...abcd.
</code></pre></div></div>
<h1 id="using-real-data">Using real data</h1>
<p>The trivial example above should give us some confidence that the code we wrote works as expected.
For a bit of fun, let’s run the same code on a slightly more interesting dataset.
<a href="https://www.gutenberg.org/">Project Gutenberg</a> hosts a library of free ebooks, from which I pulled a copy of The Metamorphosis by Franz Kafka.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">kafka</span> <span class="o">=</span> <span class="n">get_text</span><span class="p">(</span><span class="s">'kafka.txt'</span><span class="p">)</span>
<span class="n">state_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">vocab_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">kafka</span><span class="p">))</span>
<span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">ChaRNN</span><span class="p">(</span><span class="n">state_size</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">)</span>
<span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">init</span><span class="p">(</span><span class="n">subkey</span><span class="p">,</span> <span class="n">model</span><span class="p">.</span><span class="n">init_state</span><span class="p">(),</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Model state size: </span><span class="si">{</span><span class="n">model</span><span class="p">.</span><span class="n">state_size</span><span class="si">}</span><span class="s">, vocab size: </span><span class="si">{</span><span class="n">model</span><span class="p">.</span><span class="n">vocab_size</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="n">result</span><span class="p">,</span> <span class="n">losses</span><span class="p">,</span> <span class="n">bridge</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">kafka</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="mi">400</span><span class="p">,</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
<span class="n">max_epoch_size</span><span class="o">=</span><span class="mi">10000</span><span class="p">,</span>
<span class="n">weight_decay</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">,</span>
<span class="n">sample_every</span><span class="o">=</span><span class="mi">25</span><span class="p">,</span> <span class="n">sample_prompt</span><span class="o">=</span><span class="s">"Gregor"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Model state size: 128, vocab size: 64
Training RNN on 'I
One mo...'
Vocabulary size: 64
State size: 128
Adam optimizer parameters
learning_rate=0.002
weight_decay=0.000
====================
Epoch: 0 loss: 3.134 time: 7.32
Sample: Gregor to to oe to oe oe to te oe to te oe oe t
Epoch: 5 loss: 2.329 time: 3.34
Epoch: 10 loss: 2.042 time: 3.34
Epoch: 15 loss: 2.031 time: 3.24
Epoch: 20 loss: 1.756 time: 3.27
Epoch: 25 loss: 1.663 time: 3.25
Sample: Gregor and her her the was not her the was not her the w
Epoch: 30 loss: 1.687 time: 3.22
Epoch: 35 loss: 1.641 time: 3.23
Epoch: 40 loss: 1.527 time: 3.27
Epoch: 45 loss: 1.590 time: 3.24
Epoch: 50 loss: 1.614 time: 3.25
Sample: Gregor was of the father was of the father was of the fa
Epoch: 55 loss: 1.488 time: 3.25
Epoch: 60 loss: 1.411 time: 3.25
Epoch: 65 loss: 1.495 time: 3.35
Epoch: 70 loss: 1.589 time: 3.21
Epoch: 75 loss: 1.467 time: 3.30
Sample: Gregor’s bone the could have the could have the could ha
Epoch: 80 loss: 1.459 time: 3.24
Epoch: 85 loss: 1.469 time: 3.31
Epoch: 90 loss: 1.513 time: 3.37
Epoch: 95 loss: 1.400 time: 3.31
Epoch: 100 loss: 1.460 time: 3.27
Sample: Gregor’s sister was nothing the door and state and state
Epoch: 105 loss: 1.369 time: 3.27
Epoch: 110 loss: 1.429 time: 3.29
Epoch: 115 loss: 1.389 time: 3.29
Epoch: 120 loss: 1.366 time: 3.30
Epoch: 125 loss: 1.459 time: 3.23
Sample: Gregor’s mother would be the table to the table to the t
Epoch: 130 loss: 1.330 time: 3.24
Epoch: 135 loss: 1.387 time: 3.30
Epoch: 140 loss: 1.335 time: 3.29
Epoch: 145 loss: 1.325 time: 3.25
Epoch: 150 loss: 1.350 time: 3.23
Sample: Gregor’s father to his father to his father to his fathe
</code></pre></div></div>
<p>After running for 400 epochs, and seeing the training loss flatten, we can sample a longer snippet to see where the model has landed:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">sample</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">result</span><span class="p">.</span><span class="n">target</span><span class="p">,</span> <span class="n">bridge</span><span class="p">,</span> <span class="n">initial</span><span class="o">=</span><span class="s">'Gregor'</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">500</span><span class="p">)</span>
</code></pre></div></div>
<p>outputs</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Gregor’s father and he had been to his father and he had been
to his father and he had been to his father and he had been
to his father and he had been to his father and he had been
to his father and he had been to his father and he had been
to his father and he had been to his father and he had been
...
</code></pre></div></div>
<p>While the model is able to string together words quite quickly, the deterministic sampling settles into repeating cycles.</p>
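<p>A common remedy, not covered in the post, is to sample the next character from the predicted distribution rather than always taking the most likely one; a sketch in plain NumPy (<code class="language-plaintext highlighter-rouge">decode_stochastic</code> is a hypothetical replacement for the greedy decoder):</p>

```python
import numpy as np

def decode_stochastic(predictions, id_to_char, rng):
    # sample the next character from the predicted distribution
    # instead of greedily taking the argmax; this breaks the
    # deterministic repetition at the cost of noisier output
    i = int(rng.choice(len(predictions), p=np.asarray(predictions)))
    return id_to_char[i]

rng = np.random.default_rng(0)
# with all probability mass on index 1, this always returns "b"
char = decode_stochastic(np.array([0.0, 1.0, 0.0]), {0: "a", 1: "b", 2: "c"}, rng)
```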
<h1 id="wrapping-up">Wrapping up</h1>
<p>Of course, from here on out we can improve on all aspects of this modeling exercise, but for now, I hope this has been useful in getting a better grasp of Jax.
Personally, I find Jax a pleasure to work with: the functional style makes it easy to write transparent code.
Furthermore, if you have experience with Numpy, the syntax feels immediately familiar.
It’s true that the handling of randomness takes some getting used to,
and libraries such as Flax do some magic under the hood to make the functional approach practical, but both trade-offs are for the greater good.</p>
<p><em>If you have questions, comments or find an error, please reach out via Twitter or email!</em></p>Sven SchmitJax is a relatively new Python library aimed as a drop in replacement for Numpy for machine learning research. It sets itself apart due to its functional approach, which I find really enjoyable. Recently I have been playing around with implementing a simple RNN using Flax to get beyond the basics of Jax, but without adding all the bells and whistles.Experimentation with resource constraints2020-11-19T04:26:24+00:002020-11-19T04:26:24+00:00https://svenschmit.com/virtual-warehouse<p>Based on work at Stitch Fix around experimentation with resource constraints and our introduction of the “virtual warehouse”, Greg Novak, Dave Spiegel and I wrote <a href="https://multithreaded.stitchfix.com/blog/2020/11/18/virtual-warehouse/">a post on the Multithreaded blog</a>.</p>
<p>We introduce the post with the following thought experiment:</p>
<blockquote>
<p>Suppose a group of squirrels is considering two possible strategies to survive a harsh winter: either (A) gorge themselves on acorns in the fall and try to make it through the winter on whatever they can find, or (B) bury some acorns in the fall which will then be ready to be dug up in the winter.</p>
</blockquote>
<blockquote>
<p>Having read a bit about data science, they might choose to A/B test these strategies, randomizing squirrels into two groups. Of course, if the squirrels are sharing the same part of the woods, we can immediately see which group will have a better chance of remaining well fed until springtime — the squirrels of group A with their “greedy” strategy (always optimizing instantaneous rate of calorie consumption) get to stuff themselves all autumn long without setting aside any nuts for the future and continue the eating through the winter by digging up the nuts from their thrifty buddies in group B who have saved acorns throughout the forest floor. Group B has a strategy that actually might be superior if it were rolled out to all squirrels, but if they share the same region as group A, their sacrificing some feasting in the fall won’t lead them to be any better off in the winter.</p>
</blockquote>
<blockquote>
<p>In contrast, if the two experimental groups were placed in separate forests, they would get a better measure of what it would be like to roll out a strategy for all squirrels. Maybe strategy B — saving some acorns for later — is better for squirreldom than greedy strategy A, and maybe not; but the only way an A/B test could possibly reveal B as the winner is if the two groups are not competing for the same underlying resource. Thus, the randomized assignment of squirrels to the two strategies is not good enough; we also have to ensure resources of the two groups are independent.</p>
</blockquote>
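<p>The thought experiment above can be sketched as a toy simulation. Everything below is invented for illustration: the 60/40 eat-versus-bury split, the winter scarcity multiplier, and the rate at which group A digs up group B’s caches are arbitrary assumptions, not numbers from the post.</p>

```python
import random

def simulate(shared_forest, n=1000, seed=0):
    """Toy model of the squirrel A/B test under two experiment designs."""
    rng = random.Random(seed)
    # Each squirrel finds roughly 100 acorns in the fall.
    fall = sum(rng.uniform(80, 120) for _ in range(n)) / n

    winter_value = 1.5  # an acorn eaten in the scarce winter is worth more

    # Group A (greedy): eat everything in the fall.
    a_calories = fall
    # Group B (thrifty): eat 60% now, bury 40% for the winter.
    b_fall = 0.6 * fall
    buried = 0.4 * fall

    if shared_forest:
        # Interference: group A digs up half of group B's caches in winter.
        stolen = 0.5 * buried
        a_calories += winter_value * stolen
        b_calories = b_fall + winter_value * (buried - stolen)
    else:
        # Separate forests: B keeps everything it buried.
        b_calories = b_fall + winter_value * buried

    return a_calories, b_calories

a_shared, b_shared = simulate(shared_forest=True)
a_split, b_split = simulate(shared_forest=False)
# Shared forest: interference makes greedy strategy A look like the winner...
assert a_shared > b_shared
# ...while separate forests reveal B's true advantage.
assert b_split > a_split
```

In this toy model the sign of the A/B comparison flips purely because of the experiment design: randomizing squirrels within one forest measures the wrong counterfactual, while randomizing forests isolates the two resource pools.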
<p>For the entire post, head over to the <a href="https://multithreaded.stitchfix.com/blog/2020/11/18/virtual-warehouse/">Multithreaded blog</a>.</p>

<h1 id="large-scale-experimentation">Large scale experimentation</h1>
<p><em>2020-07-08 · <a href="https://svenschmit.com/large-scale-experimentation">svenschmit.com/large-scale-experimentation</a></em></p>

<p><a href="http://web.stanford.edu/~rjohari/">Ramesh Johari</a>, <a href="https://www.linkedin.com/in/virag-shah-bb986419/">Virag Shah</a>, and I wrote an academic paper on “large scale experimentation” that introduces a framework for thinking about optimal testing when there are many possible experiments to run.</p>
<p>To accompany the more technical academic paper, I wrote <a href="https://multithreaded.stitchfix.com/blog/2020/07/07/large-scale-experimentation/">a blog post with a more intuitive exposition of the main ideas for the Multithreaded blog</a>.</p>
<p>My favorite aspect of this blog post is the interactive visualization by <a href="https://www.linkedin.com/in/brianedwardcoffey/">Brian Coffey</a>, which encourages the reader to explore different testing strategies.</p>

<h1 id="multiple-hypothesis-testing">Multiple hypothesis testing</h1>
<p><em>2015-10-16 · <a href="https://svenschmit.com/multiple-hypothesis-testing">svenschmit.com/multiple-hypothesis-testing</a></em></p>

<p>Inspired by the lack of blog posts about multiple hypothesis testing, <a href="https://multithreaded.stitchfix.com/blog/2015/10/15/multiple-hypothesis-testing/">I wrote a post on multiple hypothesis testing for the Multithreaded blog</a>.
The post covers three aspects of hypothesis testing:</p>
<ul>
<li>The difference between p-values and the “probability of being right”</li>
<li>How to combine many hypotheses into a single test</li>
<li>How to deal with multiple hypotheses individually</li>
</ul>
<p>With the help of the Algo UI team, it also contains interactive visualizations that showcase the differences between frequentist and Bayesian viewpoints on testing, and how the Benjamini-Hochberg procedure controls the false discovery rate.</p>
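<p>The Benjamini-Hochberg step-up procedure mentioned above is short enough to sketch directly. This is a minimal illustrative implementation with made-up example p-values, not the code behind the post’s visualization:</p>

```python
def benjamini_hochberg(pvalues, alpha=0.05):
    """Indices of hypotheses rejected by the BH step-up procedure at level alpha."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    # Find the largest rank k (1-indexed) with p_(k) <= (k / m) * alpha.
    k_max = 0
    for rank, i in enumerate(order, start=1):
        if pvalues[i] <= rank / m * alpha:
            k_max = rank
    # Reject the k_max smallest p-values -- including any that missed their
    # own threshold but are "rescued" by a larger rank passing (step-up).
    return sorted(order[:k_max])

# None of the last three p-values clears a Bonferroni cutoff of 0.05/4,
# yet BH rejects all four because p_(4) = 0.04 <= (4/4) * 0.05.
print(benjamini_hochberg([0.001, 0.02, 0.022, 0.04]))  # → [0, 1, 2, 3]
```

The step-up structure is what distinguishes BH from per-test corrections like Bonferroni: whether one hypothesis is rejected can depend on the other p-values in the batch.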
<p>The material is based on the fantastic Stanford <a href="https://statweb.stanford.edu/~candes/teaching/stats300c/">Statistics 300C</a> course taught by Prof. Candes.</p>