Arkved's Tower

I dropped a neural net

Recently, I worked through my first Jane Street puzzle. Not one of their math puzzles on the main website (I still have trouble with those), but one of their ML puzzles on their Hugging Face repo. The current puzzle (as of this writing) is called "I dropped a neural net".

The puzzle

Oh no! I dropped an extremely valuable trading model and it fell apart into linear layers! I need to rebuild it before anyone notices, but I can't remember how these pieces go together, or how it was trained.

All I have left are the pieces of the model and some historical data. Can you help me figure out how to put it back together?

Luckily I still have the source code of the layers that the neural network is made of. They look like this:

class Block(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.inp = nn.Linear(in_dim, hidden_dim)
        self.activation = nn.ReLU()
        self.out = nn.Linear(hidden_dim, in_dim)

    def forward(self, x):
        residual = x
        x = self.inp(x)
        x = self.activation(x)
        x = self.out(x)
        return residual + x

class LastLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.layer = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        return self.layer(x)

The solution to this puzzle is a permutation. For each index from 0 to 96, you give the index of the piece that is applied in that position.

Summary

To put it simply, the model has been shattered into 97 linear layer files, and we need to reconstruct the exact order of the original forward pass.

The answer to the puzzle is a permutation, but we cannot find it by searching over permutations directly.

The problem is that there are 97! ways to order the pieces, roughly 9.62 × 10^151 possible permutations. To put that into perspective, there are only about 10^80 atoms in the observable universe. If we had 10^80 supercomputers each checking one permutation per Planck time (roughly 10^43 checks per second), it would still take us about 3 × 10^21 years to try every permutation, which is over 200 billion times the current age of the universe.
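As a quick sanity check on these magnitudes, here is a short script that redoes the arithmetic. It assumes 10^80 machines, 10^43 checks per second each (about one per Planck time), a Julian year, and a universe age of 13.8 billion years; all of these are rough figures.

```python
import math

SECONDS_PER_YEAR = 365.25 * 24 * 3600   # Julian year, ~3.156e7 s
AGE_OF_UNIVERSE_YEARS = 13.8e9           # approximate

n_perms = math.factorial(97)             # 97! orderings of the pieces
machines = 10**80                        # one "supercomputer" per atom
checks_per_second = 10**43               # ~one check per Planck time

# Exact big-int division, rounded to a float
seconds = n_perms / (machines * checks_per_second)
years = seconds / SECONDS_PER_YEAR

print(f"97! ~ {n_perms:.3e}")
print(f"time ~ {years:.2e} years")
print(f"~ {years / AGE_OF_UNIVERSE_YEARS:.1e} x the age of the universe")
```

Using Python's arbitrary-precision integers for 97! avoids any overflow before the final division.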

Clearly, we need to look into the model architecture and find some clues. The good news is that the architecture leaks the model structure, and the historical data from the CSV file lets us check whether a proposed reconstruction behaves like the original model.

Step 1: Understanding the model architecture

Refer back to the code provided in the problem statement. Each Block is a residual block. The forward pass is:

def forward(self, x):
    residual = x
    x = self.inp(x)
    x = self.activation(x)
    x = self.out(x)
    return residual + x

Mathematically, this is

$$x \mapsto x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$$

In other words,

$$\text{new state} = \text{old state} + \text{learned correction}$$

The model keeps a current internal state x. Each residual block then computes a correction to that state and adds it back. The skip connection is the x+... part.

This matters because it constrains the shapes of the layers inside each block. If the model state has a dimension d, then the block has

$$\mathrm{Linear}(d \to h) \;\to\; \mathrm{ReLU} \;\to\; \mathrm{Linear}(h \to d)$$

The first layer expands the state from dimension d to dimension h. The second layer then compresses it back from h to d, so that the correction can be added back to the original d-dimensional state.

This means that if we can infer d and h, we can identify which files are the first half of a residual block and which files are the second half.
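To make the shape constraint concrete, here is a minimal NumPy sketch of one block's forward pass. It re-expresses Block.forward rather than reproducing puzzle code. It assumes weights stored in nn.Linear's (out_features, in_features) layout and uses toy sizes (d = 4, h = 8) with random values.

```python
import numpy as np

def residual_block(x, W_in, b_in, W_out, b_out):
    """One Block.forward in NumPy, for a batch x of shape (n, d).

    Weights follow nn.Linear's layout: W_in is (h, d), W_out is (d, h).
    """
    hidden = np.maximum(0.0, x @ W_in.T + b_in)   # Linear(d -> h), then ReLU
    correction = hidden @ W_out.T + b_out         # Linear(h -> d)
    return x + correction                         # skip connection

rng = np.random.default_rng(0)
d, h, n = 4, 8, 5                                 # toy sizes, not the puzzle's
x = rng.normal(size=(n, d))
W_in, b_in = rng.normal(size=(h, d)), rng.normal(size=h)
W_out, b_out = rng.normal(size=(d, h)), rng.normal(size=d)

y = residual_block(x, W_in, b_in, W_out, b_out)
print(y.shape)   # (5, 4): same shape as x, so the correction can be added back
```

Only an h → d second layer lets the addition `x + correction` type-check, which is why the block structure pins down the layer shapes.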

Step 2: Inferring the dimensions

The historical CSV is useful even before doing any modeling, as it tells us the input dimension. The CSV contains feature columns of the form:

measurement_0, measurement_1, ..., measurement_47

This tells us that there are 48 input features, so the model's input state is

$$x \in \mathbb{R}^{48}$$

The CSV also contains a pred column, which is the scalar output of the original model, so the final output dimension is 1. Since every residual block maps the 48-dimensional state back to 48 dimensions, the final layer must be

$$\mathrm{Linear}(48 \to 1)$$
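A minimal sketch of reading these dimensions off the header. It uses an in-memory stand-in for the CSV built from the column names described above. The real file name and any extra columns are not assumed, which is why the features are selected by prefix.

```python
import csv
import io

# Hypothetical stand-in for the historical CSV; only the header matters here.
header = ",".join([f"measurement_{i}" for i in range(48)] + ["pred"])
sample_csv = io.StringIO(header + "\n")

columns = next(csv.reader(sample_csv))            # first row = column names
feature_cols = [c for c in columns if c.startswith("measurement_")]

in_dim = len(feature_cols)      # width of the input state x
has_pred = "pred" in columns    # a single scalar target column
print(in_dim, has_pred)   # 48 True
```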

I then inspected the provided .pth files to find the tensor shapes. The weights came in three shapes:

$$(96, 48), \qquad (48, 96), \qquad (1, 48)$$

We know that nn.Linear(in_features, out_features) stores its weight matrix with the shape (out_features, in_features). So these shapes correspond to

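$$\begin{aligned}
(96, 48) &\;\Rightarrow\; \mathrm{Linear}(48 \to 96)\\
(48, 96) &\;\Rightarrow\; \mathrm{Linear}(96 \to 48)\\
(1, 48) &\;\Rightarrow\; \mathrm{Linear}(48 \to 1)
\end{aligned}$$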

This tells us the exact architecture dimensions:

in_dim = 48, hidden_dim = 96, out_dim = 1

So every residual block must be

$$\mathrm{Linear}(48 \to 96) \;\to\; \mathrm{ReLU} \;\to\; \mathrm{Linear}(96 \to 48) \;\to\; \text{residual add}$$

and the model ends with

$$\mathrm{Linear}(48 \to 1)$$

So this leaves us with the following:

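$$\begin{aligned}
48 \text{ layers with shape } (96, 48) &\;\to\; \text{first half of residual blocks}\\
48 \text{ layers with shape } (48, 96) &\;\to\; \text{second half of residual blocks}\\
1 \text{ layer with shape } (1, 48) &\;\to\; \text{final prediction layer}
\end{aligned}$$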

So now we know

  1. Which piece is the final layer
  2. Which 48 pieces are possible W_in layers
  3. Which 48 pieces are possible W_out layers
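The shape-based sorting can be sketched as below. It is a hypothetical helper that assumes you have already collected each piece's weight shape (for example from the .pth files via torch.load). The file layout and key names are not assumed here, so the demo uses a synthetic shape table.

```python
from collections import defaultdict

def classify_pieces(shapes):
    """Group piece indices by the role implied by their weight shape.

    `shapes` maps piece index -> weight shape (out_features, in_features).
    """
    roles = {(96, 48): "W_in", (48, 96): "W_out", (1, 48): "final"}
    groups = defaultdict(list)
    for idx, shape in sorted(shapes.items()):
        groups[roles[tuple(shape)]].append(idx)   # KeyError on an unexpected shape
    return dict(groups)

# Synthetic shape table (real pieces are shuffled; this only exercises the code).
fake = {i: (96, 48) for i in range(48)}
fake.update({i: (48, 96) for i in range(48, 96)})
fake[96] = (1, 48)

groups = classify_pieces(fake)
print({role: len(ix) for role, ix in groups.items()})   # {'W_in': 48, 'W_out': 48, 'final': 1}
```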

Now the next step is to pair each Linear(48 → 96) with its matching Linear(96 → 48) and order the residual blocks that we recover.

Step 3: Pairing the two halves of each residual block

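The next step is to determine which W_in belongs with each W_out. We know that each residual block contains two trained matrices:

$$W_{\text{in}}: \mathbb{R}^{48} \to \mathbb{R}^{96}, \qquad W_{\text{out}}: \mathbb{R}^{96} \to \mathbb{R}^{48}$$

In other words:

$$W_{\text{in}} \text{ has shape } (96, 48), \qquad W_{\text{out}} \text{ has shape } (48, 96)$$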

Two matrices inside the same block were trained together, so they should have a detectable relationship. A random W_in from one block and a random W_out from another block should fit together less cleanly.

For each candidate pair, I looked at W_out @ W_in. The dimensions multiply out as follows:

(48×96)(96×48)=(48×48)

Essentially we are asking, "If two layers were chained together, what transformation would they apply?" Now of course, this is not the exact residual block because the real block is

$$x + W_{\text{out}}\,\mathrm{ReLU}(W_{\text{in}}\,x + b_{\text{in}}) + b_{\text{out}}$$

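The product W_out @ W_in ignores the ReLU and the biases. It is exactly what the residual branch would compute (minus the biases) if every hidden unit were active, so it still captures the linear part of the residual update.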

For a correct pair, the product W_out @ W_in represents the approximate linear behavior of one coherent residual correction. For an incorrect pair, it combines the first half of one block with the second half of a different block. That product should be less structured.

Empirically, the correct pairs showed a much stronger, more structured diagonal pattern.

Correct pair heatmap

Incorrect pair heatmap

The correct pair is not a perfect identity matrix, but the diagonal signal is noticeably stronger. The incorrect pair is significantly noisier.

Why do we care about the diagonal structure?

The residual block takes a 48-dimensional state and returns a 48-dimensional correction that gets added back to that same state, so it is plausible for the learned correction to preserve some coordinate-wise structure. The diagonal entry M_ii of M = W_out @ W_in measures how much input coordinate i contributes back to output coordinate i through the linearized residual branch. A stronger diagonal therefore means that coordinate i of the input tends to influence coordinate i of the residual correction.
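Written out term by term (still ignoring the ReLU and biases), the diagonal entry is a sum over the 96 hidden units k. In each term, (W_in)_{ki} is how strongly unit k reads input coordinate i, and (W_out)_{ik} is how strongly it writes back to coordinate i:

$$M = W_{\text{out}} W_{\text{in}}, \qquad M_{ii} = \sum_{k=1}^{96} (W_{\text{out}})_{ik}\,(W_{\text{in}})_{ki}$$

A large |M_ii| means these 96 contributions mostly share one sign instead of cancelling out.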

Of course, the block can still have off-diagonal interactions, but correct pairs had much more coherent structure than mismatched pairs.

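One simple score I used was:

import numpy as np

def pair_score(W_in, W_out):
    # Linearized residual branch (ignores ReLU and biases): (48, 96) @ (96, 48) -> (48, 48)
    M = W_out @ W_in

    diag = np.diag(M)               # 1-D array of the diagonal entries
    off_diag = M - np.diag(diag)    # M with its diagonal zeroed out

    diag_strength = np.sum(np.abs(diag))           # L1 mass on the diagonal
    off_diag_strength = np.linalg.norm(off_diag)   # Frobenius norm off the diagonal

    # Small epsilon guards against division by zero
    return diag_strength / (off_diag_strength + 1e-12)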

This score measures diagonal mass (the sum of absolute diagonal entries) relative to off-diagonal mass (the Frobenius norm of everything else).

I computed this score for every possible (W_in, W_out) pair. There are only 48 × 48 = 2,304 candidate pairs, so this step is cheap to compute.
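A minimal sketch of filling the full score matrix. It repeats pair_score so the block runs on its own. The demo data is a synthetic assumption: each "true" W_out is set to the transpose of its W_in, which makes W_out @ W_in = W_in.T @ W_in strongly diagonal. Real trained pairs are not transposes; the construction only exercises the code.

```python
import numpy as np

def pair_score(W_in, W_out):
    # Same score as above: |diag| mass over off-diagonal Frobenius norm.
    M = W_out @ W_in
    diag = np.diag(M)
    off_diag = M - np.diag(diag)
    return np.sum(np.abs(diag)) / (np.linalg.norm(off_diag) + 1e-12)

def score_matrix(W_ins, W_outs):
    """scores[i, j] = pair_score(W_ins[i], W_outs[j]) for every candidate pair."""
    scores = np.empty((len(W_ins), len(W_outs)))
    for i, W_in in enumerate(W_ins):
        for j, W_out in enumerate(W_outs):
            scores[i, j] = pair_score(W_in, W_out)
    return scores

# Synthetic pieces (assumption): W_outs[i] is paired with W_ins[i].
rng = np.random.default_rng(0)
W_ins = [rng.normal(size=(96, 48)) for _ in range(6)]
W_outs = [W.T.copy() for W in W_ins]

scores = score_matrix(W_ins, W_outs)
print(scores.argmax(axis=1))   # [0 1 2 3 4 5]
```

Rows index W_in candidates and columns index W_out candidates, which is the layout the assignment step below expects.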

Once every candidate pair has a score, we need a one-to-one matching. Each W_in should be paired with exactly one W_out, and each W_out should be used exactly once.

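This is a bipartite assignment problem, so I used the Hungarian algorithm via SciPy's linear_sum_assignment (current SciPy implements it as a modified Jonker-Volgenant algorithm, which solves the same problem):

from scipy.optimize import linear_sum_assignment

# scores[i, j] = pair_score of the i-th W_in with the j-th W_out, shape (48, 48)
rows, cols = linear_sum_assignment(-scores)
# W_in number rows[k] is paired with W_out number cols[k]

The negative sign is there because linear_sum_assignment minimizes total cost by default, whereas I wanted to maximize the pairing score. Passing maximize=True is an equivalent alternative.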
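To make explicit what the assignment step optimizes, here is a brute-force reference for tiny matrices. It is a hypothetical helper, not SciPy code. It tries all n! one-to-one matchings, so it is useless at n = 48 but handy for sanity-checking a solver on small inputs.

```python
from itertools import permutations

def brute_force_assignment(scores):
    """Return the column assignment maximizing total score (tiny n only).

    Exhaustive search over n! matchings of a square score matrix.
    """
    n = len(scores)
    best_cols, best_total = None, float("-inf")
    for cols in permutations(range(n)):
        total = sum(scores[i][c] for i, c in enumerate(cols))
        if total > best_total:
            best_cols, best_total = list(cols), total
    return best_cols, best_total

# Row 0's best column (0) is also row 1's, so greedy row-by-row picking scores only 15.
scores = [
    [9, 8, 1],
    [9, 1, 1],
    [1, 1, 5],
]
print(brute_force_assignment(scores))   # ([1, 0, 2], 22)
```

The example shows why a one-to-one matching is needed. Letting each W_in greedily grab its best W_out can assign one W_out twice or force a worse overall pairing.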

After this step, the 96 hidden pieces were recovered as 48 residual blocks.

So the problem has been reduced even further. We originally had 97 random files, and now we have 48 residual blocks and one final layer which is appended at the end. Now we have to order all of our residual blocks.

Step 4: Ordering the recovered residual blocks

The final layer is known by its shape. The remaining problem is to determine the order of the 48 residual blocks.

Naively, this is still huge, as there are 48! ≈ 1.24 × 10^61 possible orderings of the residual blocks. However, this is where the provided CSV file is particularly useful.

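The historical data can be used as a scoring function. For any candidate block order, I can reconstruct a full model:

$$\begin{aligned}
x_0 &= \text{input row from CSV}\\
x_{k+1} &= x_k + W_{\text{out},k}\,\mathrm{ReLU}(W_{\text{in},k}\,x_k + b_{\text{in},k}) + b_{\text{out},k}, \qquad k = 0, \dots, 47\\
\text{prediction} &= W_{\text{final}}\,x_{48} + b_{\text{final}}
\end{aligned}$$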

Then I compare that reconstructed prediction to the pred column from the CSV. The loss is mean squared error:

$$\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(\hat{y}_i^{\,\text{reconstructed}} - y_i^{\text{original}}\right)^2$$

The correct ordering should produce MSE equal to zero up to numerical precision.
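A minimal sketch of this scoring function. The names (reconstruct, mse, and the (W_in, b_in, W_out, b_out) tuple layout) are hypothetical. The toy model uses random weights and small sizes, and its own true-order output stands in for the CSV's pred column.

```python
import numpy as np

def reconstruct(X, blocks, order, W_final, b_final):
    """Run the reconstructed model on X (n, d) with blocks applied in `order`.

    `blocks` is a list of (W_in, b_in, W_out, b_out) tuples in nn.Linear layout.
    """
    x = X
    for k in order:
        W_in, b_in, W_out, b_out = blocks[k]
        x = x + np.maximum(0.0, x @ W_in.T + b_in) @ W_out.T + b_out
    return (x @ W_final.T + b_final).ravel()   # Linear(d -> 1), flattened to (n,)

def mse(y_hat, y):
    return float(np.mean((y_hat - y) ** 2))

# Toy model with random weights (hypothetical sizes, not the puzzle's).
rng = np.random.default_rng(1)
d, h, n, n_blocks = 4, 8, 32, 3
blocks = [(rng.normal(size=(h, d)), rng.normal(size=h),
           rng.normal(size=(d, h)), rng.normal(size=d)) for _ in range(n_blocks)]
W_final, b_final = rng.normal(size=(1, d)), rng.normal(size=1)
X = rng.normal(size=(n, d))

y_true = reconstruct(X, blocks, [0, 1, 2], W_final, b_final)   # stands in for `pred`
loss_true = mse(reconstruct(X, blocks, [0, 1, 2], W_final, b_final), y_true)
loss_swapped = mse(reconstruct(X, blocks, [1, 0, 2], W_final, b_final), y_true)
print(loss_true)      # 0.0
print(loss_swapped)   # strictly positive
```

With the real pieces, `loss` for a candidate order would be `mse(reconstruct(X, blocks, order, W_final, b_final), pred)`.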

Seeding the order with the output-layer norm

The heuristic that worked well was sorting blocks by the Frobenius norm of their output projection:

$$\lVert W_{\text{out}} \rVert_F$$

The Frobenius norm of a matrix is:

$$\lVert W \rVert_F = \sqrt{\sum_{i}\sum_{j} W_{ij}^2}$$

It is just a measure of the overall size of the matrix. For each recovered block, I computed

score = np.linalg.norm(W_out)

and sorted the blocks by that score. It turns out that this was surprisingly close to the true order.
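A sketch of the seeding step. The sort direction is an assumption because the post does not say ascending or descending, so it is exposed as a flag. The demo uses synthetic W_out matrices with deliberately different scales.

```python
import numpy as np

def norm_seed_order(W_outs, descending=False):
    """Block indices sorted by the Frobenius norm of each W_out.

    The direction is an assumption (the text does not specify it).
    """
    norms = np.array([np.linalg.norm(W) for W in W_outs])   # Frobenius for 2-D input
    order = np.argsort(norms, kind="stable")
    return order[::-1] if descending else order

rng = np.random.default_rng(0)
W_outs = [s * rng.normal(size=(48, 96)) for s in (3.0, 1.0, 2.0)]   # synthetic

# np.linalg.norm on a matrix defaults to the Frobenius norm:
W = W_outs[0]
assert np.isclose(np.linalg.norm(W), np.sqrt(np.sum(W ** 2)))

print(norm_seed_order(W_outs))   # [1 2 0]
```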

My intuition was that residual networks are not arbitrary compositions of matrices. Each block is a learned perturbation added to the current state, and the scale of that perturbation changes across depth. The norm of W_out is not enough to prove the order, but it gives a very good seed.

This norm-based ordering is the step that makes the local search fairly simple. If the seed were random, adjacent swaps would not be enough. But if the seed is already almost correct, adjacent swaps can repair the remaining local inversions.

Step 5: Finalizing the order with adjacent swaps

After sorting by ‖W_out‖_F, I ran an adjacent-swap local search.

The general procedure for the search is as follows:

  1. Start with the norm-sorted order
  2. Compute the MSE against historical predictions
  3. Try swapping neighboring blocks
  4. Keep the swap if it improves MSE
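  5. Repeat until no adjacent swap improves the loss

def adjacent_swap_search(order, loss):
    # `loss` maps a block order to the MSE against the historical predictions
    order = list(order)
    best = loss(order)

    while True:
        improved = False

        for i in range(len(order) - 1):
            # Swap blocks i and i + 1 in a copy of the current order
            candidate = order.copy()
            candidate[i], candidate[i + 1] = candidate[i + 1], candidate[i]

            candidate_loss = loss(candidate)

            # Keep the swap only if it strictly lowers the loss
            if candidate_loss < best:
                order = candidate
                best = candidate_loss
                improved = True

        # Stop once a full pass makes no improvement
        if not improved:
            break

    return order, best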

This is somewhat like bubble sort, except the comparison is not whether the number at index i is greater than the number at index i + 1. Instead, it checks whether swapping two neighboring residual blocks brings the reconstructed model closer to the original model.
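The bubble-sort analogy can be made exact with a hypothetical toy loss. If the MSE is replaced by a count of inverted pairs, every adjacent swap changes that count by exactly one, so the search sorts the list. This compact re-implementation takes the loss as a parameter and is only meant to illustrate the analogy.

```python
def adjacent_swap_search(order, loss):
    """Greedy adjacent-swap refinement, with the loss function passed in."""
    order, best = list(order), loss(order)
    improved = True
    while improved:
        improved = False
        for i in range(len(order) - 1):
            cand = order.copy()
            cand[i], cand[i + 1] = cand[i + 1], cand[i]
            cand_loss = loss(cand)
            if cand_loss < best:          # accept only strict improvements
                order, best, improved = cand, cand_loss, True
    return order, best

def inversions(order):
    """Toy loss: number of out-of-order pairs (stands in for the MSE)."""
    return sum(a > b for i, a in enumerate(order) for b in order[i + 1:])

# A nearly sorted seed, like the norm-based ordering, needs only two fixes.
print(adjacent_swap_search([0, 2, 1, 3, 5, 4], inversions))   # ([0, 1, 2, 3, 4, 5], 0)
```

With the real MSE there is no such guarantee, which is why a good seed matters: from a random start, greedy adjacent swaps can stall in a local minimum.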

The MSE curve looked like this

MSE during adjacent-swap refinement

Every accepted swap lowers the error, but most of them only by a little. Then, near the end, there is a huge drop.

That final drop is the interesting part. It means the model was already almost structurally correct, but one critical local inversion was still wrong. Because residual blocks are applied sequentially, an ordering mistake changes the internal state, and all later blocks operate on that changed state. Once the last important local inversion was fixed, the entire forward pass lined up with the original model and the error collapsed to numerical precision.
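To see why a single inversion matters, write f_a and f_b for the residual branches of two adjacent blocks, each of the form f(x) = W_out ReLU(W_in x + b_in) + b_out. The two possible orders give:

$$\begin{aligned}
a \text{ then } b:&\quad x \mapsto x + f_a(x) + f_b\big(x + f_a(x)\big)\\
b \text{ then } a:&\quad x \mapsto x + f_b(x) + f_a\big(x + f_b(x)\big)
\end{aligned}$$

These agree only when f_b(x + f_a(x)) - f_b(x) = f_a(x + f_b(x)) - f_a(x), which in general does not hold. Whatever discrepancy remains becomes the input to every later block.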

The final MSE was 3.369516376528804e-14

Step 6: Validation

To validate my findings, I compared the original predictions to the reconstructed predictions.

Provided vs reconstructed predictions

If the model were reconstructed incorrectly, the points would drift away from the diagonal. Instead, they lie essentially exactly on the identity line, which, together with the near-zero MSE, indicates that the model has been reconstructed correctly.

Step 7: Getting the final permutation

As stated previously, each residual block has two piece indices. The first is an expand piece and the second is a compress piece. So I walked through the recovered block order and emitted:

expand_1, compress_1, expand_2, compress_2, ..., expand_48, compress_48
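A sketch of flattening the recovered order into the submission string. The helper name and argument layout are hypothetical, and the demo uses a toy model with 2 blocks and 5 pieces rather than the puzzle's 97.

```python
def build_permutation(block_order, pairs, final_idx):
    """Flatten the recovered block order into the submission permutation.

    block_order: block ids in forward-pass order.
    pairs: block id -> (expand_piece_idx, compress_piece_idx).
    final_idx: piece index of the Linear(48 -> 1) layer.
    """
    perm = []
    for b in block_order:
        expand_idx, compress_idx = pairs[b]
        perm += [expand_idx, compress_idx]
    perm.append(final_idx)
    # Every piece index must appear exactly once
    assert sorted(perm) == list(range(len(perm))), "not a permutation of 0..N-1"
    return perm

# Toy example (hypothetical indices): 2 blocks + 1 final layer = 5 pieces.
pairs = {0: (3, 0), 1: (1, 4)}
perm = build_permutation([1, 0], pairs, final_idx=2)
print(",".join(map(str, perm)))   # 1,4,3,0,2
```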

Then I appended the final Linear(48 → 1) layer to the end of the list, and I had my finalized permutation. I submitted my comma-separated permutation to the repo and was met with the following message.

Success message

Final thoughts

What I liked about this puzzle is that the model is broken, but not in a way that destroys all information. The CSV reveals the input and output dimensions. The tensor shapes reveal the roles of the layers. The matrix products reveal which halves were trained together. The historical predictions reveal whether the ordering is correct.

I ended up submitting my solution via the form they attached, and I also emailed my solution to archaeology@janestreet.com. I can finally say that I solved a Jane Street puzzle.