I dropped a neural net
Recently, I worked through my first Jane Street puzzle. Not one of their math puzzles on the main website (I still have trouble with those), but one of their ML puzzles on their Hugging Face repo. The current puzzle (as of this writing) is called "I dropped a neural net".
The puzzle
Oh no! I dropped an extremely valuable trading model and it fell apart into linear layers! I need to rebuild it before anyone notices, but I can't remember how these pieces go together, or how it was trained.
All I have left are the pieces of the model and some historical data. Can you help me figure out how to put it back together?
Luckily I still have the source code of the layers that the neural network is made of. They look like this:
```python
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.inp = nn.Linear(in_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.out = nn.Linear(hidden_dim, in_dim)

    def forward(self, x):
        residual = x
        x = self.inp(x)
        x = self.activation(x)
        x = self.out(x)
        return residual + x


class LastLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.layer(x)
```
The solution to this puzzle is a permutation. For each index from 0 to 96, you give the index of the piece that is applied in that position.
Summary
To put it simply, the model has been shattered into linear layer files, and we need to reconstruct the exact order of the original forward pass.
The answer to the puzzle is a permutation, but this is not a permutation puzzle in itself.
The problem is that there are $97!$ ways to order the pieces, or roughly $9.6 \times 10^{151}$ possible permutations. To put that into perspective, there are only about $10^{80}$ atoms in the observable universe. If we had supercomputers each computing a permutation at Planck time (which is roughly $1.9 \times 10^{43}$ checks per second), it would still take us 30 quintillion years to get that permutation, which is over 2 billion times the current age of the universe.
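A quick standard-library check of these magnitudes. The Planck time constant and the $10^{80}$ atom count are standard reference figures supplied here, not values from the puzzle:

```python
import math

# 97 pieces (indices 0..96) can be arranged in 97! different orders.
n_orders = math.factorial(97)
magnitude = math.log10(n_orders)  # order of magnitude of 97!
print(f"97! is about {float(n_orders):.2e} ({len(str(n_orders))} digits)")

# Planck time in seconds (about 5.39e-44 s): one check per Planck time.
planck_time_s = 5.391247e-44
checks_per_second = 1 / planck_time_s
print(f"one check per Planck time is about {checks_per_second:.2e} checks/s")

# Commonly quoted estimate of atoms in the observable universe.
atoms_estimate = 10**80
print(f"97! exceeds the atom estimate by about 10^{magnitude - 80:.0f}")
```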
Clearly, we need to look into the model architecture and find some clues. The good news is that the architecture leaks the model structure, and the historical data from the CSV file lets us check whether a proposed reconstruction behaves like the original model.
Step 1: Understanding the model architecture
Refer back to the code provided in the problem statement. Each Block is a residual block. The forward pass is:
```python
def forward(self, x):
    residual = x
    x = self.inp(x)
    x = self.activation(x)
    x = self.out(x)
    return residual + x
```
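Mathematically, this is
$$y = x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$$
In other words,
$$\text{new state} = \text{old state} + \text{learned correction}$$
The model keeps a current internal state $x$. Each residual block then computes a correction $f(x) = W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$ to that state and adds it back. The skip connection is the $x + {}$ part (the residual + x in the code).
This matters because it constrains the shapes of the layers inside each block. If the model state has dimension $d$ and the hidden layer has width $h$, then the block has
$$W_{\text{in}} \in \mathbb{R}^{h \times d}, \qquad W_{\text{out}} \in \mathbb{R}^{d \times h}$$
The first layer expands the state from dimension $d$ to dimension $h$. The second layer then compresses it back from $h$ to $d$, so that the correction can be added back to the original $d$-dimensional state.
This means that if we can infer $d$ and $h$, we can identify which files are the first half of a residual block and which files are the second half.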
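To make the shape bookkeeping concrete, here is a NumPy re-implementation of one Block's forward pass that follows nn.Linear's (out_features, in_features) weight layout. It is an illustrative sketch with small made-up sizes, not the puzzle's code; linear and block_forward are hypothetical names:

```python
import numpy as np


def linear(x, W, b):
    # Same convention as nn.Linear: W has shape (out_features, in_features),
    # and a batch of row vectors x is mapped to x @ W.T + b.
    return x @ W.T + b


def block_forward(x, W_in, b_in, W_out, b_out):
    # x: (batch, d); W_in: (h, d) expands d -> h; W_out: (d, h) compresses h -> d.
    hidden = np.maximum(linear(x, W_in, b_in), 0.0)  # ReLU
    correction = linear(hidden, W_out, b_out)         # back to shape (batch, d)
    return x + correction                             # skip connection


d, h = 4, 16  # illustrative sizes only
rng = np.random.default_rng(0)
x = rng.normal(size=(3, d))
W_in, b_in = rng.normal(size=(h, d)), rng.normal(size=h)
W_out, b_out = rng.normal(size=(d, h)), rng.normal(size=d)
y = block_forward(x, W_in, b_in, W_out, b_out)
print(y.shape)  # (3, 4): the block preserves the state dimension
```

Setting W_out and b_out to zero turns the block into the identity, which is a handy way to see that the skip connection carries the state through unchanged.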
Step 2: Inferring the dimensions
The historical CSV is useful even before doing any modeling, as it tells us the input dimension. The CSV contains feature columns of the form:
measurement_0, measurement_1, ..., measurement_47
It tells us that there are 48 input features, so the model input state is $x \in \mathbb{R}^{48}$, i.e. $d = 48$.
The CSV also contains a pred column, which is the scalar output of the original model. So the final output dimension is 1. From the architecture, the final layer must be LastLayer(48, 1), whose weight has shape (1, 48).
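I then inspected the provided .pth files to find the tensor shapes. The weights came in three shapes, where h is the hidden width:
- (h, 48)
- (48, h)
- (1, 48)
We know that nn.Linear(in_features, out_features) stores its weight matrix with shape (out_features, in_features). So these shapes correspond to
- (h, 48): nn.Linear(48, h), the expanding layer W_in
- (48, h): nn.Linear(h, 48), the compressing layer W_out
- (1, 48): nn.Linear(48, 1), the final layer
This tells us the exact architecture dimensions: the state dimension is $d = 48$, and the hidden width $h$ is read directly off these shapes. So every residual block must be
Block(in_dim=48, hidden_dim=h)
and the model ends with
LastLayer(in_dim=48, out_dim=1)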
So this leaves us with 48 candidate W_in pieces, 48 candidate W_out pieces, and exactly one final-layer piece. So now we know
- Which piece is the final layer
- Which pieces are possible W_in layers
- Which pieces are possible W_out layers
The next step is to pair each W_in with its matching W_out and then order the residual blocks that we recover.
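The shape-based sorting from this step can be sketched in plain Python. This is a minimal sketch assuming the weight shapes have already been collected into a dict mapping piece index to (out_features, in_features), and that h is not 48, so the two halves have distinguishable shapes; classify_pieces is a hypothetical helper, not part of the puzzle code:

```python
def classify_pieces(shapes, d=48):
    """Split piece ids into W_in, W_out and final-layer candidates.

    `shapes` maps a piece id to its weight shape in nn.Linear's
    (out_features, in_features) layout.
    """
    in_ids, out_ids, last_ids = [], [], []
    for piece_id, shape in sorted(shapes.items()):
        shape = tuple(shape)
        if shape == (1, d):
            last_ids.append(piece_id)  # nn.Linear(d, 1): the final layer
        elif shape[1] == d:
            in_ids.append(piece_id)    # nn.Linear(d, h): expands d -> h
        elif shape[0] == d:
            out_ids.append(piece_id)   # nn.Linear(h, d): compresses h -> d
        else:
            raise ValueError(f"unexpected shape {shape} for piece {piece_id}")
    return in_ids, out_ids, last_ids


# Toy example with a placeholder hidden width of 64.
print(classify_pieces({0: (64, 48), 1: (48, 64), 2: (1, 48), 3: (64, 48), 4: (48, 64)}))
```

The (1, d) check has to come first, because the final layer's weight also has d columns and would otherwise be mistaken for a W_in piece.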
Step 3: Pairing the two halves of each residual block
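The next step is to determine which W_in belongs with each W_out. We know that each residual block contains two trained matrices, $W_{\text{in}} \in \mathbb{R}^{h \times 48}$ and $W_{\text{out}} \in \mathbb{R}^{48 \times h}$. In other words:
$$\text{Block}(x) = x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$$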
Two matrices inside the same block were trained together, so they should have a detectable relationship. A random W_in from one block and a random W_out from another block should fit together less cleanly.
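For each candidate pair, I looked at W_out @ W_in. The dimensions multiply out as follows:
$$(48 \times h)\cdot(h \times 48) = 48 \times 48$$
Essentially we are asking, "If these two layers were chained together, what transformation would they apply?" Of course, this is not the exact residual block, because the real block is
$$x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$$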
The product W_out @ W_in ignores the ReLU and the biases. But it still approximates the linear part of the residual update.
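To make "the linear part" concrete, write the ReLU as a 0/1 diagonal mask $D(x)$ that keeps exactly the hidden units whose pre-activation is positive. Then the block can be rewritten without approximation as

$$
\begin{aligned}
x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}
&= x + W_{\text{out}}\,D(x)\,(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}} \\
&= \big(I + W_{\text{out}}\,D(x)\,W_{\text{in}}\big)\,x + W_{\text{out}}\,D(x)\,b_{\text{in}} + b_{\text{out}}.
\end{aligned}
$$

In the idealized case where every hidden unit is active, $D(x) = I$ and the input-dependent term becomes $(I + W_{\text{out}} W_{\text{in}})\,x$, so W_out @ W_in is the linear map the residual branch adds on top of the identity.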
For a correct pair, the product W_out @ W_in represents the approximate linear behavior of one coherent residual correction. For an incorrect pair, it combines the first half of one block with the second half of a different block. That product should be less structured.
Empirically, the correct pairs had a much stronger diagonal / structured pattern.


The correct pair is not a perfect identity matrix, but the diagonal signal is noticeably stronger. The incorrect pair is significantly noisier.
Why do we care about the diagonal structure?
The residual block takes a 48-dimensional state and returns a 48-dimensional correction that gets added back to that same state. So it is natural for the learned correction to preserve some coordinate-wise structure. A diagonal entry $M_{ii}$ of $M = W_{\text{out}} W_{\text{in}}$ measures how much input coordinate $i$ contributes back to output coordinate $i$ through the approximate linearized residual branch. So a stronger diagonal means that the $i$-th coordinate of the input tends to influence the $i$-th coordinate of the residual correction.
Of course, the block can still have off-diagonal interactions, but correct pairs had much more coherent structure than mismatched pairs.
One simple score I used was
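```python
import numpy as np


def pair_score(W_in, W_out):
    # Approximate linear part of the residual branch: (48, h) @ (h, 48) -> (48, 48).
    M = W_out @ W_in
    diag = np.diag(M)
    off_diag = M - np.diag(diag)  # M with its diagonal zeroed out
    diag_strength = np.sum(np.abs(diag))
    off_diag_strength = np.linalg.norm(off_diag)  # Frobenius norm
    # Small epsilon guards against division by zero.
    return diag_strength / (off_diag_strength + 1e-12)
```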
This score measures diagonal mass relative to off-diagonal mass.
I computed this score for every possible (W_in, W_out) pair. There are only $48 \times 48 = 2{,}304$ candidate pairs, so this step is actually computationally feasible.
Once every candidate pair has a score, we need a one-to-one matching. Each W_in should be paired with exactly one W_out, and each W_out should be used exactly once.
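This is a bipartite assignment problem, the classic use case for the Hungarian algorithm, so I used SciPy's assignment solver:
```python
from scipy.optimize import linear_sum_assignment

# scores: matrix of pair_score values, one row per W_in piece and one column per W_out piece
rows, cols = linear_sum_assignment(-scores)
```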
The negative sign is because linear_sum_assignment solves a minimization problem by default, whereas I wanted to maximize the pairing score.
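Putting the score and the assignment together, here is a self-contained sketch on synthetic data. The planted structure (W_out = W_in.T for each true pair) is an assumption chosen only so that the correct matching is known in advance; real trained blocks are messier. match_halves is a hypothetical helper, and SciPy's maximize=True flag is used in place of negating the scores:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def pair_score(W_in, W_out):
    # Diagonal mass of W_out @ W_in relative to its off-diagonal mass.
    M = W_out @ W_in
    diag = np.diag(M)
    off_diag = M - np.diag(diag)
    return np.sum(np.abs(diag)) / (np.linalg.norm(off_diag) + 1e-12)


def match_halves(W_ins, W_outs):
    # scores[i, j]: how well W_ins[i] and W_outs[j] fit together as one block.
    scores = np.array([[pair_score(W_in, W_out) for W_out in W_outs]
                       for W_in in W_ins])
    # maximize=True does the same job as passing -scores.
    rows, cols = linear_sum_assignment(scores, maximize=True)
    return dict(zip(rows.tolist(), cols.tolist()))


# Synthetic check, NOT the puzzle data: each planted pair uses W_out = W_in.T,
# which gives W_out @ W_in a strong positive diagonal.
rng = np.random.default_rng(0)
d, h, n_blocks = 16, 64, 8
W_ins = [rng.normal(size=(h, d)) for _ in range(n_blocks)]
perm = rng.permutation(n_blocks)
W_outs = [W_ins[p].T for p in perm]  # second halves in shuffled order
matching = match_halves(W_ins, W_outs)  # maps W_in index -> W_out index
print(matching)
```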
After this step, the 96 hidden pieces were recovered as 48 residual blocks.
So the problem has been reduced even further. We originally had 97 random files, and now we have 48 residual blocks and one final layer which is appended at the end. Now we have to order all of our residual blocks.
Step 4: Ordering the recovered residual blocks
The final layer is known by its shape. The remaining problem is to determine the order of the 48 residual blocks.
Naively, this is still huge, as there are $48! \approx 1.2 \times 10^{61}$ permutations of these residual blocks. However, this is where the provided CSV file is particularly useful.
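The historical data can be used as a scoring function. For any candidate block order $\pi$, I can reconstruct a full model
$$\hat{y} = \text{LastLayer}\big(B_{\pi(48)} \circ \cdots \circ B_{\pi(2)} \circ B_{\pi(1)}(x)\big),$$
where $B_k$ is the $k$-th recovered residual block. Then I compare that reconstructed prediction to the pred column from the CSV. The loss is mean squared error:
$$\text{MSE}(\pi) = \frac{1}{N}\sum_{n=1}^{N}\big(\hat{y}_n - y_n\big)^2,$$
where $y_n$ is the pred value in row $n$ and $N$ is the number of rows.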
The correct ordering should produce MSE equal to zero up to numerical precision.
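A minimal NumPy sketch of this scoring function. The data layout is an assumption: blocks[k] holds (W_in, b_in, W_out, b_out) for recovered block k, last holds the final layer's (weight, bias), and forward and order_mse are hypothetical names. The toy data below stands in for the real pieces and the CSV:

```python
import numpy as np


def forward(X, blocks, order, last):
    # Apply the residual blocks in the given order, then the final layer.
    x = X
    for k in order:
        W_in, b_in, W_out, b_out = blocks[k]
        x = x + np.maximum(x @ W_in.T + b_in, 0.0) @ W_out.T + b_out
    W_last, b_last = last
    return (x @ W_last.T + b_last).ravel()  # LastLayer: d -> 1


def order_mse(X, y, blocks, order, last):
    # Mean squared error against the historical predictions.
    y_hat = forward(X, blocks, order, last)
    return float(np.mean((y_hat - y) ** 2))


# Toy stand-in for the real pieces and CSV (illustrative sizes only).
rng = np.random.default_rng(1)
d, h, n_blocks, n_rows = 6, 12, 4, 200
blocks = [(rng.normal(scale=0.3, size=(h, d)), rng.normal(scale=0.1, size=h),
           rng.normal(scale=0.3, size=(d, h)), rng.normal(scale=0.1, size=d))
          for _ in range(n_blocks)]
last = (rng.normal(size=(1, d)), rng.normal(size=1))
X = rng.normal(size=(n_rows, d))
true_order = [2, 0, 3, 1]
y = forward(X, blocks, true_order, last)  # plays the role of the pred column
print(order_mse(X, y, blocks, true_order, last))
print(order_mse(X, y, blocks, [0, 2, 3, 1], last))  # first two blocks swapped
```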
Seeding the order with the output-layer norm
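The heuristic that worked well was sorting blocks by the Frobenius norm of their output projection, $\|W_{\text{out}}\|_F$.
The Frobenius norm of a matrix $A$ is:
$$\|A\|_F = \sqrt{\sum_{i}\sum_{j} A_{ij}^2}$$
It is just a measure of the overall size of the matrix. For each recovered block, I computed
```python
score = np.linalg.norm(W_out)  # for a 2-D array this defaults to the Frobenius norm
```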
and sorted the blocks by that score. It turns out that this was surprisingly close to the true order.
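A small sketch of the seeding step. norm_seed_order and its descending flag are hypothetical names, and the sort direction is left as a parameter rather than fixed:

```python
import numpy as np


def norm_seed_order(W_outs, descending=False):
    # Frobenius norm of each block's output projection
    # (np.linalg.norm's default for a 2-D array).
    norms = np.array([np.linalg.norm(W) for W in W_outs])
    order = np.argsort(norms)  # block indices from smallest to largest norm
    return (order[::-1] if descending else order).tolist()


# Toy example: norms are 4, 2 and 6 respectively.
example = [2 * np.ones((2, 2)), np.ones((2, 2)), 3 * np.ones((2, 2))]
print(norm_seed_order(example))
print(norm_seed_order(example, descending=True))
```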
My intuition was that residual networks are not arbitrary compositions of matrices. Each block is a learned perturbation added to the current state, and the scale of that perturbation changes across depth. The norm of W_out is not enough to prove the order, but it gives a very good seed.
This norm-based ordering is the step that makes the local search fairly simple. If the seed were random, adjacent swaps would not be enough. But if the seed is already almost correct, adjacent swaps can repair the remaining local inversions.
Step 5: Finalizing the order with adjacent swaps
After sorting by $\|W_{\text{out}}\|_F$, I ran an adjacent-swap local search.
The general procedure for the search is as follows:
- Start with the norm-sorted order
- Compute the MSE against historical predictions
- Try swapping neighboring blocks
- Keep the swap if it improves MSE
- Repeat until no adjacent swap improves the loss
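```python
def adjacent_swap_search(order):
    # loss(order) is the MSE of the reconstructed model against the CSV's
    # pred column, as described in Step 4 (defined elsewhere).
    best = loss(order)
    while True:
        improved = False
        for i in range(len(order) - 1):
            # Swap neighbors i and i + 1; keep the swap only if the MSE drops.
            candidate = order.copy()
            candidate[i], candidate[i + 1] = candidate[i + 1], candidate[i]
            candidate_loss = loss(candidate)
            if candidate_loss < best:
                order = candidate
                best = candidate_loss
                improved = True
        # Stop once a full pass finds no improving swap.
        if not improved:
            break
    return order, best
```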
This is somewhat like bubble sort, except the comparison is not about checking whether the number at index $i$ is greater than the number at index $i + 1$. Instead, the comparison checks whether swapping two neighboring residual blocks makes the reconstructed model closer to the original model.
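To make the bubble-sort analogy concrete, here is a self-contained variant of the search that takes the loss as a parameter (loss_fn and inversions are hypothetical names for this illustration). With a toy loss that counts out-of-order pairs, swapping two neighbors lowers the loss exactly when they are inverted, so the search reduces to bubble sort:

```python
def adjacent_swap_search(order, loss_fn):
    # Same procedure as above, with the loss passed in explicitly.
    order = list(order)
    best = loss_fn(order)
    improved = True
    while improved:
        improved = False
        for i in range(len(order) - 1):
            candidate = order.copy()
            candidate[i], candidate[i + 1] = candidate[i + 1], candidate[i]
            candidate_loss = loss_fn(candidate)
            if candidate_loss < best:
                order, best, improved = candidate, candidate_loss, True
    return order, best


def inversions(order):
    # Toy loss: number of pairs (a, b) where a appears before b but a > b.
    return sum(a > b for i, a in enumerate(order) for b in order[i + 1:])


print(adjacent_swap_search([0, 2, 1, 3, 5, 4], inversions))
```

With the real MSE loss there is no such guarantee: the search can stop at a local minimum, which is why a nearly correct seed matters.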
The MSE curve looked like this

Most accepted swaps gradually reduce the error. Then near the end, there is a huge drop.
That final drop is the interesting part. It means the model was already almost structurally correct, but one critical local inversion was still wrong. Because residual blocks are applied sequentially, an ordering mistake changes the internal state, and all later blocks operate on that changed state. Once the last important local inversion was fixed, the entire forward pass lined up with the original model and the error collapsed to numerical precision.
The final MSE was 3.369516376528804e-14.
Step 6: Validation
To validate my findings, I compared the original predictions to the reconstructed predictions.

If the model were reconstructed incorrectly, the points would drift away from the diagonal. Instead, they lie essentially exactly on the identity line, so I can conclude that the model has been reconstructed correctly.
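The scatter plot can be backed by a numeric check. A minimal sketch, where validate and the 1e-6 tolerance are hypothetical choices rather than anything from the puzzle:

```python
import numpy as np


def validate(y_true, y_hat, atol=1e-6):
    # Numeric counterpart to the scatter plot: every point should sit on y = x.
    y_true, y_hat = np.asarray(y_true), np.asarray(y_hat)
    residual = np.abs(y_hat - y_true)
    return {
        "max_abs_error": float(residual.max()),
        "mse": float(np.mean(residual ** 2)),
        "on_identity_line": bool(np.allclose(y_hat, y_true, rtol=0.0, atol=atol)),
    }


print(validate([1.0, 2.0], [1.0, 2.0 + 1e-9]))
```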
Step 7: Getting the final permutation
As stated previously, each residual block has two piece indices. The first is an expand piece and the second is a compress piece. So I walked through the recovered block order and emitted:
expand_1, compress_1, expand_2, compress_2, ..., expand_48, compress_48
Then I appended the final layer to the end of the list, and I had my finalized permutation. I submitted my comma-separated permutation to the repo, and was met with the following message.

Final thoughts
What I liked about this puzzle is that the model is broken, but not in a way that destroys all information. The CSV reveals the input and output dimensions. The tensor shapes reveal the roles of the layers. The matrix products reveal which halves were trained together. The historical predictions reveal whether the ordering is correct.
I ended up submitting my solution via the form they attached, and I also emailed my solution to archaeology@janestreet.com. I can finally say that I solved a Jane Street puzzle.