"""nanogpt_mini.py, M17: a real (tiny) TRANSFORMER language model from scratch, in PyTorch.

This is the actual architecture behind modern LLMs, shrunk to something a laptop can train in
about a minute on a tiny text corpus, inspired by Andrej Karpathy's nanoGPT. It has the real pieces:
token + position embeddings, causal multi-head self-attention, MLP blocks with residuals and
layernorm, a training loop, and text generation. Same idea as tiny_lm.py, now with attention.

Setup:  pip install torch          (use Python 3.10-3.12; torch lags the newest Python)
Run:    python nanogpt_mini.py
"""

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

# ---- data + char tokenizer --------------------------------------------------
TEXT = "hello world. this is a tiny transformer learning to predict text. " * 80
chars = sorted(set(TEXT))                       # vocabulary: each distinct character, in sorted order
V = len(chars)                                  # vocabulary size
itos = dict(enumerate(chars))                   # id -> character
stoi = {ch: idx for idx, ch in itos.items()}    # character -> id (inverse of itos)
data = torch.tensor([stoi[ch] for ch in TEXT], dtype=torch.long)  # the whole corpus as ids

# ---- hyperparameters (small on purpose) -------------------------------------
BLOCK = 32        # context window: the most characters the model looks at in one pass
N_EMB = 64        # width of every token / position embedding vector
N_HEAD = 4        # attention heads per block (N_EMB has to be divisible by this)
N_LAYER = 2       # how many transformer blocks are stacked
STEPS = 300       # optimizer steps in the training loop


def get_batch(batch_size=16):
    """Sample ``batch_size`` random training windows from ``data``.

    Returns ``(x, y)``: two long tensors of shape (batch_size, BLOCK) cut from the same
    random start offsets. ``y`` is offset by one, so ``y[:, t]`` is the character that
    follows ``x[:, t]`` (the next-char target at every position).
    """
    starts = torch.randint(len(data) - BLOCK, (batch_size,))

    def windows(shift):
        # one BLOCK-long slice per start offset, moved `shift` characters to the right
        return torch.stack([data[s + shift:s + shift + BLOCK] for s in starts])

    return windows(0), windows(1)


class Block(nn.Module):
    """One pre-norm transformer block: causal self-attention + an MLP, each wrapped in a residual.

    Layernorm is applied to the input of each sub-layer, i.e. ``x + f(ln(x))``.

    Args:
        n_emb: embedding width; defaults to the module-level ``N_EMB``.
        n_head: number of attention heads (must divide ``n_emb``); defaults to ``N_HEAD``.
    """
    def __init__(self, n_emb=None, n_head=None):
        super().__init__()
        n_emb = N_EMB if n_emb is None else n_emb
        n_head = N_HEAD if n_head is None else n_head
        self.ln1 = nn.LayerNorm(n_emb)
        self.attn = nn.MultiheadAttention(n_emb, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_emb)
        self.mlp = nn.Sequential(nn.Linear(n_emb, 4 * n_emb), nn.GELU(), nn.Linear(4 * n_emb, n_emb))

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, T, n_emb); returns a tensor of the same shape."""
        T = x.size(1)
        # True above the diagonal means "may not attend": position t only sees positions <= t.
        # Built on x's device so the block keeps working after model.to("cuda").
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.ln1(x)                             # normalize once; used as query, key and value
        a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + a                                   # residual
        x = x + self.mlp(self.ln2(x))               # residual
        return x


class MiniGPT(nn.Module):
    """Character-level GPT: token + position embeddings -> N_LAYER blocks -> layernorm -> vocab logits."""
    def __init__(self):
        super().__init__()
        self.tok = nn.Embedding(V, N_EMB)           # token embeddings
        self.pos = nn.Embedding(BLOCK, N_EMB)       # learned position embeddings, one per context slot
        self.blocks = nn.Sequential(*[Block() for _ in range(N_LAYER)])
        self.ln = nn.LayerNorm(N_EMB)
        self.head = nn.Linear(N_EMB, V)             # project back to next-char scores

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (batch, T), with T <= BLOCK.

        ``logits`` has shape (batch, T, V). ``loss`` is the mean next-char cross-entropy
        against ``targets`` (same shape as ``idx``), or None when ``targets`` is None.
        """
        T = idx.size(1)
        positions = torch.arange(T, device=idx.device)   # keep on the input's device (CPU or GPU)
        x = self.tok(idx) + self.pos(positions)          # (T, N_EMB) position table broadcasts over batch
        logits = self.head(self.ln(self.blocks(x)))      # (batch, T, V)
        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            # reshape (not view) also accepts non-contiguous targets.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, n=80):
        """Sample ``n`` new tokens after ``idx`` (batch, T), one at a time; return the extended ids.

        Only the last BLOCK tokens are fed to the model, because the position table has BLOCK rows.
        """
        for _ in range(n):
            logits, _ = self(idx[:, -BLOCK:])            # crop to the context window
            probs = F.softmax(logits[:, -1, :], dim=-1)  # distribution over the next char
            nxt = torch.multinomial(probs, 1)            # (batch, 1) sampled ids
            idx = torch.cat([idx, nxt], dim=1)
        return idx


if __name__ == "__main__":
    # train: plain next-char prediction with AdamW, logging every 100 steps and at the end
    model = MiniGPT()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    last_step = STEPS - 1
    for step in range(STEPS):
        xb, yb = get_batch()
        _, loss = model(xb, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == last_step or step % 100 == 0:
            print(f"step {step:3d}  loss {loss.item():.3f}")

    # sample: continue from a single 'h' for 120 characters
    prompt = torch.tensor([[stoi['h']]])
    sample_ids = model.generate(prompt, n=120)[0].tolist()
    text = "".join(itos[i] for i in sample_ids)
    print("\ngenerated:\n" + text)
