💾 Archived View for osanwe.benson.earth › s › Cognition › 61 captured on 2024-12-17 at 09:49:10. Gemini links have been rewritten to link to archived content

-=-=-=-=-=-=-

How Determinantal Point Processes Support Out-Of-Distribution Generalization

Introduction

Deep Neural Networks are impressively good at doing human-ish things which, until recently, computers just couldn't do. However, one area where humans excel but DNNs are still terrible is out-of-distribution generalization (OODG): applying knowledge learned in one context to a context that has never been encountered before. Modern AI can sort of do this in practice, since models like Claude and GPT-4 presumably weren't trained on most of the prompts they get from users, and they do okay. While we don't fully understand what's going on here, we do know that this only happens when DNNs are trained on massive amounts of data. A smaller model trained to be really good at one task isn't going to be good at a different task without further training, even if the tasks are pretty analogous.

Many people have the intuition that while humans can make real analogies, large DNNs that seem to be doing OODG are really just interpolating. Some people say that DNNs can't make anything 'genuinely new'. This is a fuzzy concept, and it's unclear whether humans can make anything genuinely new by the standards that would disqualify a frontier LLM. Still, part of why people make that claim is that these models are trained on SO MUCH data, so they might just be interpolating instead of doing the kind of generalization that we humans think we're doing.

So people in academia are experimenting with different architectures which are better at OODG. It's worth noting that companies don't seem to be putting resources into this, since they seem to think the best way to improve their models is to scale them up and then have the models recursively self-improve, rather than having humans make major new architecture improvements.

This is a write up of Mondal et al, August 2024, "Determinantal point process attention over grid cell code supports out of distribution generalization".

DPP-A

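Cognitive scientists working on this problem are trying to figure out what allows the brain to do OODG so well. The guess explored in this paper has two parts: the brain encodes information in abstract representations that capture relational structure (like the grid cell code described below), and it attends selectively to the parts of those representations that will carry over to new problems.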

The math for exactly how the abstract representations are chosen actually comes to us from quantum physics. In quantum physics there are probabilistic models of repulsion called determinantal point processes (DPPs): they make it unlikely that two very similar items get picked together, so the sets they favor are diverse. When a DPP is used in this cognitive science context to decide which parts of a representation to attend to, we call it DPP-A, or DPP attention. The goal here is to pick out the parts of a DNN's internal representations of data that have a clean relational structure, which can then be applied to all sorts of similarly structured problems.

In order to do this, the DNN doesn't just have to build representations of the training data; it has to identify which aspects of those representations could be applied to different problems. This can be done by only processing the parts of the representation with the least redundant structure, ones that are as close to 'pairwise uncorrelated' as possible. A little more formally, we compute the covariance matrix of the representation's units over the training data and pick the subset of units whose covariance submatrix has the largest determinant. The determinant is large when the chosen units each vary a lot and are weakly correlated with one another. That's all DPP-A really is.
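Here is a minimal sketch of that selection step in Python with NumPy. It is not the paper's code: it uses a plain greedy approximation (add one unit at a time, keeping whichever unit gives the selected covariance submatrix the largest log-determinant), and the function name `greedy_dpp_select` and the toy data are hypothetical.

```python
import numpy as np


def greedy_dpp_select(features: np.ndarray, k: int) -> list[int]:
    """Greedily choose up to k columns (units) of `features` (samples x units)
    whose covariance submatrix has the largest log-determinant."""
    cov = np.cov(features, rowvar=False)  # units x units covariance over the samples
    selected: list[int] = []
    for _ in range(k):
        best_unit, best_logdet = None, -np.inf
        for j in range(cov.shape[0]):
            if j in selected:
                continue
            idx = selected + [j]
            sign, logdet = np.linalg.slogdet(cov[np.ix_(idx, idx)])
            # Skip candidates that make the submatrix singular (perfectly redundant units).
            if sign > 0 and logdet > best_logdet:
                best_unit, best_logdet = j, logdet
        if best_unit is None:  # no remaining unit adds any new variance
            break
        selected.append(best_unit)
    return selected


# Toy data: unit 1 is a near-copy of unit 0; unit 2 is independent but quieter.
rng = np.random.default_rng(0)
a = 3.0 * rng.normal(size=500)
b = rng.normal(size=500)
features = np.column_stack([a, a + 0.01 * rng.normal(size=500), b])
print(greedy_dpp_select(features, 2))  # one of units 0/1, plus unit 2
```

The near-copy is passed over because adding it would shrink the determinant toward zero, which is the "least redundant" preference described above, even though its variance is much higher than unit 2's.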

Empirical Tests

This was tested, perhaps somewhat strangely, using the extremely simple case of analogies between points on the integer plane. You have a point A, a point B, and the vector AB between them; given a point C, you can add the vector AB to it to find a point D such that the analogy A:B::C:D holds.
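The analogy rule itself is just vector addition. This small Python sketch (function names are hypothetical) computes the correct D and picks it out of a multiple-choice list. It is the ground-truth rule the model has to learn, not the model itself.

```python
Point = tuple[int, int]


def complete_analogy(a: Point, b: Point, c: Point) -> Point:
    """Return D such that A:B::C:D, i.e. D = C + (B - A)."""
    return (c[0] + b[0] - a[0], c[1] + b[1] - a[1])


def pick_answer(a: Point, b: Point, c: Point, choices: list[Point]) -> Point:
    """Pick the choice closest (squared Euclidean distance) to the ideal D."""
    dx, dy = complete_analogy(a, b, c)
    return min(choices, key=lambda p: (p[0] - dx) ** 2 + (p[1] - dy) ** 2)


A, B, C = (1, 2), (4, 3), (10, 10)
print(complete_analogy(A, B, C))  # (13, 11)
print(pick_answer(A, B, C, [(12, 12), (13, 11), (7, 9), (14, 10)]))  # (13, 11)
```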

The model was trained on points in one square region of the plane and then tested on points in a different square region: either a larger one containing it (scaling) or one translated diagonally in the positive x and y directions (translation).
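To make the two test conditions concrete, here is a sketch that samples analogy problems whose four points all fall inside a given square. The region bounds (a 100×100 training square, a 200×200 scaled test square, and the training square shifted by 100 in x and y) are made-up illustrative values, not the paper's exact numbers, and `sample_analogy` is a hypothetical helper.

```python
import random

Point = tuple[int, int]

# Hypothetical half-open squares [lo, hi) x [lo, hi); not the paper's exact values.
TRAIN = (0, 100)
SCALED_TEST = (0, 200)        # larger square containing the training square
TRANSLATED_TEST = (100, 200)  # training square shifted diagonally by (100, 100)


def sample_analogy(lo: int, hi: int, rng: random.Random) -> tuple[Point, Point, Point, Point]:
    """Sample A, B, C uniformly from the square; keep them only if D = C + (B - A) is inside too."""
    while True:
        a, b, c = [(rng.randrange(lo, hi), rng.randrange(lo, hi)) for _ in range(3)]
        d = (c[0] + b[0] - a[0], c[1] + b[1] - a[1])
        if lo <= d[0] < hi and lo <= d[1] < hi:
            return a, b, c, d


rng = random.Random(0)
for name, (lo, hi) in [("train", TRAIN), ("scaling", SCALED_TEST), ("translation", TRANSLATED_TEST)]:
    print(name, sample_analogy(lo, hi, rng))
```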

The task was, given A, B, and C, to pick D from a multiple-choice set. Here's a little ASCII diagram if it helps:

 Scaling                                Translation
┌──────────────────────────┐           ┌─────────────────────────┐
│                          │           │            ┌─────────┐  │
│                          │           │            │         │  │
├───────────────────┐      │           │            │  Test   │  │
│                   │      │           │            │         │  │
│        Test       │      │           │            │         │  │
│                   │      │           │            └─────────┘  │
├─────────┐         │      │           ├─────────┐               │
│         │         │      │           │         │               │
│  Train  │         │      │           │  Train  │               │
│         │         │      │           │         │               │
│         │         │      │           │         │               │
└─────────┴─────────┴──────┘           └─────────┴───────────────┘

Model Architecture

Grid Cell Code

The model that did this task was designed to imitate the grid cell code, which has been found in brains to represent abstract relationships spatially. As with the brain's grid cell code, the model used a hexagonal grid (with three basis vectors at angles 0, π/3, and 2π/3) over the 2D vector space being mapped. Each grid cell has a firing frequency (a spatial frequency over the hex grid). The lowest frequency is 0.0028 · 2π, and the frequencies scale up in a geometric progression with a factor of √2. A little equation:

Fₙ = F₀ · (√2)ⁿ

This factor has been shown to somehow minimize the number of variables needed in an encoding, although I don't at all understand why. It's also apparently biologically realistic.
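For concreteness, here are the nine frequencies that equation produces, computed in Python. Reading F₀ as an angular frequency in radians per grid unit is an assumption; under it, each frequency's pattern repeats every 2π/Fₙ units.

```python
import math

F0 = 0.0028 * 2 * math.pi  # base spatial frequency (radians per grid unit, assumed)
N_FREQ = 9                 # number of frequencies used in the model

freqs = [F0 * math.sqrt(2) ** n for n in range(N_FREQ)]
periods = [2 * math.pi / f for f in freqs]  # distance after which each pattern repeats

for n, (f, p) in enumerate(zip(freqs, periods)):
    print(f"F_{n} = {f:.5f}  period ~ {p:.1f}")
```

The coarsest frequency repeats roughly every 357 units and the finest roughly every 22, since the finest is (√2)⁸ = 16 times the coarsest.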

What's easier to understand is that the number of possible states increases exponentially with the number of different frequencies, so even as complexity increases the number of different frequencies used in the grid cell code can stay very low. This model used nine frequencies, just because. It seems like mammals use between 8 and 12 depending (roughly) on how smart they are.
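A toy 1D illustration of that capacity argument (a sketch of my own, not from the paper): quantize each frequency's phase into 4 levels and count how many distinct combined codes show up across 1000 integer positions. One frequency alone can only separate positions into 4 classes, but with m frequencies there are 4^m possible combinations; the count actually observed here is also capped by the 1000 positions sampled. Frequencies are in cycles per unit (F₀/2π = 0.0028), and `count_distinct_codes` is a hypothetical helper.

```python
import math


def count_distinct_codes(n_modules: int, n_positions: int = 1000,
                         bins: int = 4, f0: float = 0.0028) -> int:
    """Count distinct joint codes over integer positions 0..n_positions-1 on a line,
    when each frequency's phase is quantized into `bins` levels."""
    codes = set()
    for x in range(n_positions):
        code = tuple(
            int(bins * ((x * f0 * math.sqrt(2) ** n) % 1.0))  # quantized phase of frequency n
            for n in range(n_modules)
        )
        codes.add(code)
    return len(codes)


counts = [count_distinct_codes(m) for m in range(1, 10)]
print(counts)
```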

There are also 100 phase offsets, which ensure that every point in the territory is covered by each frequency. So overall you get 9 × 100 = 900 grid cells, each with its own activation. Take any integer point from the picture above, and you get a set of 900 grid cell activations, which together form the grid cell embedding of that point.
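Here is a sketch of that 900-dimensional embedding in Python with NumPy. Two details are assumptions rather than the paper's exact recipe: each cell's activation is a sum of three cosine gratings along the three basis directions (a standard idealized model of a hexagonal grid cell), and the 100 phase offsets sit on a 10×10 lattice spanning one period of each frequency.

```python
import numpy as np

F0 = 0.0028 * 2 * np.pi
FREQS = F0 * np.sqrt(2) ** np.arange(9)  # nine frequencies, ratio √2
ANGLES = np.array([0.0, np.pi / 3, 2 * np.pi / 3])
DIRECTIONS = np.stack([np.cos(ANGLES), np.sin(ANGLES)], axis=1)  # (3, 2) unit vectors


def grid_embedding(points: np.ndarray) -> np.ndarray:
    """Map (P, 2) points to (P, 900) activations: 9 frequencies x 100 phase offsets."""
    steps = np.arange(10) / 10.0
    unit_offsets = np.stack(np.meshgrid(steps, steps, indexing="ij"), axis=-1).reshape(-1, 2)  # (100, 2)
    blocks = []
    for f in FREQS:
        offsets = unit_offsets * (2 * np.pi / f)        # spread offsets over one period
        diff = points[:, None, :] - offsets[None, :, :]  # (P, 100, 2)
        proj = diff @ DIRECTIONS.T                       # (P, 100, 3)
        blocks.append(np.cos(f * proj).sum(axis=-1))    # (P, 100): sum of three gratings
    return np.concatenate(blocks, axis=1)


pts = np.array([[0, 0], [3, 5], [42, 17]], dtype=float)
emb = grid_embedding(pts)
print(emb.shape)  # (3, 900)
```

Because the third direction is the sum of the other two, each activation stays between -1.5 and 3, peaking at 3 when the point sits exactly on the cell's phase offset.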

All this is powerful because, with this representation, translation is just a phase difference, and scaling is just a frequency difference, so patterns are going to look mostly the same in this embedding independent of translation and scaling.
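Both claims can be checked numerically with a single idealized grid cell (same assumed three-cosine model as above; `grid_cell` and the sample numbers are hypothetical). Translating the point is the same as shifting the phase offset the other way, and scaling the point by √2 is the same as moving to the next frequency up, with the offset rescaled by 1/√2.

```python
import math

DIRECTIONS = [(math.cos(t), math.sin(t)) for t in (0.0, math.pi / 3, 2 * math.pi / 3)]


def grid_cell(x: float, y: float, freq: float, ox: float, oy: float) -> float:
    """Idealized hexagonal grid cell with phase offset (ox, oy): three cosine gratings 60° apart."""
    return sum(math.cos(freq * (ux * (x - ox) + uy * (y - oy))) for ux, uy in DIRECTIONS)


F0 = 0.0028 * 2 * math.pi
x, y, ox, oy = 37.0, 12.0, 5.0, 3.0

# Translation: moving the point by (t, t) equals moving the phase offset by (-t, -t).
t = 150.0
moved_point = grid_cell(x + t, y + t, F0, ox, oy)
moved_phase = grid_cell(x, y, F0, ox - t, oy - t)

# Scaling: scaling the point by √2 equals using the next frequency (F0·√2),
# with the phase offset scaled by 1/√2.
s = math.sqrt(2)
scaled_point = grid_cell(s * x, s * y, F0, ox, oy)
next_frequency = grid_cell(x, y, F0 * s, ox / s, oy / s)

print(math.isclose(moved_point, moved_phase, abs_tol=1e-9))      # True
print(math.isclose(scaled_point, next_frequency, abs_tol=1e-9))  # True
```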

DPP-A again

A grid cell encoding, however, isn't all we need to achieve OODG. The next puzzle piece is attending to the right parts of the embedding: the parts which contain the relational structure of the encoded information. As discussed before, DPP-A finds these parts by looking for the subset of grid cells whose covariance matrix across the training data has the largest determinant. That favors cells which have high variance (they change a lot depending on which point is being represented) and low correlation with the other selected cells (they are not redundantly representing something that's already captured by another grid cell in the subset).
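One way to see why the determinant rewards exactly those two properties: the log-determinant of a covariance matrix Σ splits into a variance term plus a correlation term, log det Σ = Σᵢ log σᵢ² + log det R, where R is the correlation matrix. Since det R ≤ 1, with equality only when the units are uncorrelated, correlation can only pull the score down. The Python/NumPy check below uses random toy data, and `logdet_split` is a hypothetical helper.

```python
import numpy as np


def logdet_split(features: np.ndarray) -> tuple[float, float]:
    """Split log det(covariance) into a variance term and a correlation term."""
    cov = np.cov(features, rowvar=False)
    variances = np.diag(cov)
    corr = cov / np.sqrt(np.outer(variances, variances))  # correlation matrix R
    variance_term = float(np.sum(np.log(variances)))      # sum of log variances
    correlation_term = float(np.linalg.slogdet(corr)[1])  # log det R, never above 0
    return variance_term, correlation_term


rng = np.random.default_rng(1)
features = rng.normal(size=(200, 4)) @ rng.normal(size=(4, 4))  # four correlated units
var_term, corr_term = logdet_split(features)
total = np.linalg.slogdet(np.cov(features, rowvar=False))[1]
print(np.isclose(total, var_term + corr_term), corr_term <= 0)  # True True
```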

Test Results

So how does this grid cell encoding with DPP-A perform? Well, on translated and scaled test data, it does great, with near 100% accuracy across the board. Its competitors (a variety of similarly sized standard models or similar models using non-DPP-A methods for selecting a subset of grid cells) suck, with less than 50% accuracy. This pattern holds not just for 'analogy' with vectors but also for multiple choice addition problems. When it comes to OODG for multiplication, DPP-A has a bit more trouble, getting to only 65% accuracy. While this is still impressive compared to 20% accuracy for all other models, it suggests that something about multiplication isn't easily captured in the grid cell embedding.

In fact, when the authors compared the grid cell embeddings of the training data with those of the test data, the difference was 5 times greater for multiplication than for addition.

Conclusion

Large DNNs get away with an architecture that doesn't handle OODG very well by training on a very, very large domain. However, with a great deal more work, it's plausible that architecture changes in frontier DNN models could meaningfully improve their ability to do OODG. As for modeling humans, it's not clear whether grid cell code plus DPP-A succeeds in recreating our generalization abilities: it seems to be mostly geared towards representing linear relationships, and it struggles with nonlinear relationships like multiplication. It's quite possible, however, that a bias towards linear relationships is hardwired in humans as well. We shall see.

Posted in: s/Cognition

🦉 satya [mod]

Dec 16 · 2 days ago