I decided I needed to play more. Specifically, I wanted an excuse to use Google’s new T5Gemma2 encoder-decoder model, except I did not want to use the decoder. I split the model, kept the encoder, and added a linear classifier head trained with binary cross-entropy (BCE). For data, I used an old Kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2020-fgvc7/overview

The goal was not to beat the leaderboard but to walk through training, data manipulation, and the handling of low-sample datasets. One of the most common criticisms of Kaggle is that entries are overfit. I won’t be submitting, but I did want to bound the accuracy of the top submissions so we had a rough target. The top entry scores \~0.99 AUC, but I also recorded F1 and accuracy, so I wanted benchmarks for those too.

In addition to the 0.99 AUC, I found one of the top entries on GitHub. It had 95.002% of the test columns with single-class confidence, so our upper bound is \~95% accuracy. Our lower bound, without destroying the AUC, would need to be greater than 85%. Our experiments support this range: even when AUC was static, accuracy often swung \~2% between evaluation runs.
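To see why accuracy can move while AUC stays flat, it helps to remember that AUC only depends on how scores rank positives against negatives, while accuracy depends on where scores land relative to a fixed threshold. The sketch below is a toy illustration with made-up scores (not data from these runs); `pairwise_auc` and `accuracy_at` are hypothetical helper names.

```python
from typing import Sequence


def pairwise_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example.")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def accuracy_at(scores: Sequence[float], labels: Sequence[int], threshold: float = 0.5) -> float:
    """Accuracy after thresholding the scores."""
    preds = [1 if s >= threshold else 0 for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


labels = [0, 0, 1, 1]
calibrated = [0.2, 0.4, 0.6, 0.8]
shifted = [0.1, 0.3, 0.45, 0.7]  # same ranking, but one positive now sits below 0.5

print("calibrated:", pairwise_auc(calibrated, labels), accuracy_at(calibrated, labels))
print("shifted:   ", pairwise_auc(shifted, labels), accuracy_at(shifted, labels))
```

Both score sets rank every positive above every negative, so the AUC is identical, yet the second set loses accuracy at the 0.5 threshold.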

The first run was a naive baseline. I wanted to set up the environment and check basic assumptions. I rented an A6000 on Prime Intellect and ran with a batch size of 8, an LR of 3e-4, cosine LR decay, and 5 epochs. I used the 270m encoder of T5Gemma2. I used the competition train set since it is the only labelled data, and split it 2:1 train:val. Only the classifier head was trained; the encoder weights were frozen. We ended with 86% accuracy, 0.514 F1 and 0.924 AUC.
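For readers unfamiliar with cosine LR decay, here is a minimal sketch of the shape of the schedule. It follows the half-cosine formula that, as far as I know, `get_cosine_schedule_with_warmup` uses with its default settings. The function name `cosine_lr`, the step count and the warmup value are assumptions for illustration; the post does not state the warmup used.

```python
import math


def cosine_lr(step: int, base_lr: float, total_steps: int, warmup_steps: int = 0) -> float:
    """Linear warmup to base_lr, then a half-cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # clamp so steps past the end stay at zero
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = 1000  # hypothetical run length
for step in (0, 250, 500, 750, 1000):
    print(step, f"{cosine_lr(step, 3e-4, total_steps):.2e}")
```

The learning rate starts at the base value, passes through half of it at the midpoint, and reaches zero on the last step.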

I then wanted to look at gradient noise (what batch size should we use?) and at the impact of extending the dataset with augmentations. I used flips, reversals, shifts, blurs and contrast adjustments. Unsurprisingly, the encoder is well trained and the augmentations had minimal effect. If the encoder were untrained, we would expect these augmentations to influence our results. The gradient noise estimate, however, showed that the optimal batch size for SNR was \~1.
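The gradient noise estimate can be sketched with two independent mini-batch gradients: their difference measures the noise, the squared norm of their average approximates the signal, and the ratio scaled by the batch size gives a rough batch size at which noise and signal are comparable. The snippet below mirrors the arithmetic of `estimate_gradient_noise_scale` in the script further down, but on synthetic NumPy vectors; `two_batch_noise_scale` and the synthetic gradients are assumptions for illustration, and a single two-sample estimate like this is very noisy.

```python
import numpy as np


def two_batch_noise_scale(g1: np.ndarray, g2: np.ndarray, batch_size: int) -> float:
    """Rough noise-scale estimate from two mini-batch gradients of the same batch size."""
    g_mean = 0.5 * (g1 + g2)
    signal = float(np.sum(g_mean ** 2)) + 1e-12  # approximates |G|^2 (biased upwards by noise)
    noise = float(np.sum((g1 - g2) ** 2))        # expectation is 2 * tr(Sigma) / batch_size
    return batch_size * noise / (2.0 * signal)


# Synthetic example: a fixed "true" gradient plus Gaussian mini-batch noise.
rng = np.random.default_rng(0)
batch_size = 8
true_grad = np.ones(16)
noise_std = 4.0 / np.sqrt(batch_size)  # per-example std of 4, averaged over the batch
g1 = true_grad + rng.normal(scale=noise_std, size=16)
g2 = true_grad + rng.normal(scale=noise_std, size=16)

print("estimated noise scale:", two_batch_noise_scale(g1, g2, batch_size))
```

In practice the estimate is averaged over many pairs of batches, which is what the `steps` argument in the script does.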

The batch size change made an immediate impact. F1 increased to 0.61, accuracy to 0.915 and AUC to 0.945. Again, there were periods where AUC was \~0.945 and accuracy was lower (\~0.900). The SNR and low optimal batch size indicate that the dataset is pretty clean or that the base encoder is rather expressive. I think it was a bit of both. Regardless, the next step was to see if we could clean up the dataset a little more.

In the next run, I used the encoder to create embeddings of the training data. I then used cosine similarity to flag any pair of images above a 0.995 threshold as duplicates. We tagged 24 pairs, which corresponded to 23 images. These images and their labels were removed from training. They could have been doubly diseased leaves, accidental duplications, or mislabelled duplicates. Deduplication does not explicitly fix mislabelled data. Again, we saw training improvements: F1 increased to 0.63, accuracy to 92.6% and AUC to 0.952.
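The duplicate check boils down to normalising the embeddings, taking all pairwise dot products, and keeping pairs above the threshold. A minimal NumPy sketch follows; `find_near_duplicates` and `flagged_images` are hypothetical names, the toy embeddings are made up, and which image of a pair to drop is a policy choice that this sketch does not make.

```python
import numpy as np


def find_near_duplicates(embeddings: np.ndarray, threshold: float = 0.995):
    """Return (i, j, cosine_sim) for every pair i < j at or above the threshold."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sims = unit @ unit.T
    # Upper triangle only (k=1): no self-pairs and each pair is counted once.
    rows, cols = np.triu_indices(len(unit), k=1)
    keep = sims[rows, cols] >= threshold
    return [(int(i), int(j), float(sims[i, j])) for i, j in zip(rows[keep], cols[keep])]


def flagged_images(pairs):
    """All image indices that appear in at least one flagged pair."""
    return sorted({i for i, _, _ in pairs} | {j for _, j, _ in pairs})


toy = np.array([[1.0, 0.0], [1.0, 0.001], [0.0, 1.0]])
pairs = find_near_duplicates(toy)
print(pairs, flagged_images(pairs))
```

For a few thousand images the full similarity matrix fits comfortably in memory; the script further down chunks the computation instead, which matters for larger sets.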

In terms of data, the next question was whether we could increase our gains by tuning noisy labels. One method would be to use gradient noise to find image-label pairs that were difficult to learn. Some of these are genuinely difficult, but often they are mislabelled, which is why the model struggles to learn them. The other option is K-fold label softening, where predictions from a trained model soften the labels. Since this is a toy project and I am out of practice identifying apple diseases, I decided to use 5 folds. I trained 5 models, each labelling a unique, held-out ⅕ of the data. I then combined these new labels with the “ground truth” labels at a ratio of 3:7. Explicitly, the model-predicted labels made up 0.3 of the label used for training on the full set. This did not noticeably improve our performance.
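The 3:7 blend is a simple weighted average of the out-of-fold predictions and the original one-hot labels. A minimal sketch, assuming the out-of-fold predictions are already sigmoid probabilities; `blend_labels` and the example values are made up for illustration.

```python
import numpy as np


def blend_labels(hard: np.ndarray, oof_probs: np.ndarray, model_weight: float = 0.3) -> np.ndarray:
    """Soft label = (1 - model_weight) * ground truth + model_weight * out-of-fold prediction."""
    if hard.shape != oof_probs.shape:
        raise ValueError(f"Shape mismatch: {hard.shape} vs {oof_probs.shape}")
    return (1.0 - model_weight) * hard + model_weight * oof_probs


# Columns follow LABEL_COLS: healthy, multiple_diseases, rust, scab.
hard = np.array([[1.0, 0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0, 0.0]])
oof = np.array([[0.6, 0.1, 0.2, 0.1],    # model mostly agrees with the label
                [0.05, 0.05, 0.1, 0.8]])  # model disagrees: it thinks "scab"

print(blend_labels(hard, oof))
```

When the held-out model disagrees with the label, the target for the labelled class drops and some mass moves to the class the model prefers, which is the softening effect. BCE accepts these fractional targets directly.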

The final addition for the 270m model was test-time augmentation (TTA) to improve our AUC. Since training-time augmentation had no noticeable effect, I did not expect TTA to improve our scores, but I wanted to give it a fair shot. I implemented 5 TTA modifications, a mix of horizontal/vertical flips and brightness augmentation. It did not meaningfully increase our performance.
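TTA amounts to running the same image through the model several times under different views and averaging the outputs. The script further down averages logits over repeated passes of randomly augmented inputs; the sketch below uses a fixed list of flips instead, which is an assumption made to keep the example deterministic. `tta_logits` and the toy `predict_fn` are hypothetical stand-ins for the real model.

```python
from typing import Callable, Sequence

import numpy as np


def tta_logits(
    predict_fn: Callable[[np.ndarray], np.ndarray],
    image: np.ndarray,
    transforms: Sequence[Callable[[np.ndarray], np.ndarray]],
) -> np.ndarray:
    """Average the logits predicted for each transformed view of one image."""
    passes = [predict_fn(t(image)) for t in transforms]
    return np.mean(np.stack(passes, axis=0), axis=0)


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


# Toy "model": the logit is just the top-left pixel, so flips change the output.
def predict_fn(img: np.ndarray) -> np.ndarray:
    return np.array([img[0, 0]])


image = np.array([[0.0, 1.0], [2.0, 3.0]])
views = [lambda x: x, np.fliplr, np.flipud]

avg = tta_logits(predict_fn, image, views)
print("averaged logit:", avg, "probability:", sigmoid(avg))
```

A well-trained encoder is already fairly invariant to these views, so the individual passes agree and the average changes little, which matches the null result above.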

Scaling is always fun, so I ran one training run with the 4B model using my best setup (deduplication, batch size of 1). Interestingly, deduplication with the larger model removed only 21 samples (as opposed to 23), and the estimated optimal batch size was higher, at 3. I ran with 1 because the noise test and the batch size setting are not wired together, so I did leave capacity on the table by not increasing the batch size.

If this were more than a toy demo, I would begin digging into these training examples:
[train] step=2410 epoch=3 loss=0.1239 lr=2.00e-04
[train] step=2420 epoch=3 loss=0.1199 lr=1.99e-04
[train] step=2430 epoch=3 loss=0.4373 lr=1.98e-04
[train] step=2440 epoch=3 loss=0.0358 lr=1.97e-04
[train] step=2450 epoch=3 loss=0.8764 lr=1.96e-04
[train] step=2460 epoch=3 loss=0.3298 lr=1.96e-04
[train] step=2470 epoch=3 loss=0.7009 lr=1.95e-04
[train] step=2480 epoch=3 loss=0.0971 lr=1.94e-04
[train] step=2490 epoch=3 loss=0.6982 lr=1.93e-04
[train] step=2500 epoch=3 loss=0.0729 lr=1.92e-04
[train] step=2510 epoch=3 loss=0.1496 lr=1.91e-04
[train] step=2520 epoch=3 loss=0.1850 lr=1.91e-04
[train] step=2530 epoch=3 loss=0.4256 lr=1.90e-04
[train] step=2540 epoch=3 loss=0.0152 lr=1.89e-04

Why are these losses so high? Our validation loss is \~0.19 so there is something interesting with the data. I could likely push the results even higher. However, scaling also gives us a decent bump.
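With a batch size of 1, each logged training loss is presumably the BCE of a single image, so a spike points at one specific example. One way to start digging is to score every training image with the trained model and sort by per-example loss. The sketch below does this in NumPy on made-up logits; `per_example_bce` and `hardest_examples` are hypothetical helpers, not part of the script.

```python
import numpy as np


def per_example_bce(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Numerically stable BCE-with-logits, averaged over the label columns of each row."""
    per_label = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return per_label.mean(axis=1)


def hardest_examples(logits: np.ndarray, labels: np.ndarray, image_ids, top_k: int = 10):
    """Return (image_id, loss) for the top_k highest-loss examples, worst first."""
    losses = per_example_bce(logits, labels)
    order = np.argsort(-losses)[:top_k]
    return [(image_ids[i], float(losses[i])) for i in order]


# Made-up logits for two images with the same label: "a" is confidently right,
# "b" is confidently wrong on the positive class.
logits = np.array([[5.0, -5.0], [-5.0, -5.0]])
labels = np.array([[1.0, 0.0], [1.0, 0.0]])

for image_id, loss in hardest_examples(logits, labels, ["a", "b"], top_k=2):
    print(image_id, round(loss, 4))
```

Images that stay at the top of this list late in training are the ones worth opening by hand: they are either genuinely hard or mislabelled.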

My final F1 is 0.663 with an accuracy of 94.4% and an AUC of 0.969. This is only good for \~450 on the leaderboard, but we’re effectively saturating the dataset. I could continue to play with the data to eke out the last few gains, but that is more time than I will dedicate today. The intent was to take the new T5 model for a spin and sharpen my classification skills.

Here is the Validation AUC chart from WandB.

Addendum: I went back to see if I could score at the top of the leaderboard. Adding an MLP head and decreasing the dropout to 0 got me an AUC of 0.9872, effectively saturating this benchmark. This is with the 270m parameter encoder. The final run and performance are below:

uv run python3 vision_dis.py \
--data-dir ./pp2020/unzipped --batch-size 1 \
--eval-every 200 --epochs 5 --remove-duplicates \
--dedupe-threshold 0.995 --classifier-head mlp --lr 1e-3 --classifier-dropout 0
[eval] step=4200 loss=0.1057 f1=0.6880 acc=0.9596 auc=0.9872

The code can be found below.

from __future__ import annotations

import argparse
import csv
import json
import math
import os
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader

import wandb
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    get_cosine_schedule_with_warmup,
)

# Optional AUROC support
try:
    from sklearn.metrics import roc_auc_score
except Exception:
    roc_auc_score = None


LABEL_COLS = ["healthy", "multiple_diseases", "rust", "scab"]


@dataclass
class EncoderClassifierOutput:
    loss: Optional[torch.Tensor]
    logits: torch.Tensor


def build_pp2020_augmentations():
    try:
        import albumentations as A
        import cv2
    except Exception as e:
        raise RuntimeError(
            "PP2020 augmentations require `albumentations` and `opencv-python`."
        ) from e

    if hasattr(A, "RandomBrightness") and hasattr(A, "RandomContrast"):
        brightness_contrast = A.OneOf(
            [A.RandomBrightness(limit=0.1, p=1), A.RandomContrast(limit=0.1, p=1)],
            p=1,
        )
    else:
        brightness_contrast = A.RandomBrightnessContrast(
            brightness_limit=0.1, contrast_limit=0.1, p=1
        )

    # Resize/normalize are handled by the processor to avoid double preprocessing.
    return A.Compose(
        [
            brightness_contrast,
            A.OneOf(
                [
                    A.MotionBlur(blur_limit=3),
                    A.MedianBlur(blur_limit=3),
                    A.GaussianBlur(blur_limit=3),
                ],
                p=0.5,
            ),
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.2,
                scale_limit=0.2,
                rotate_limit=20,
                interpolation=cv2.INTER_LINEAR,
                border_mode=cv2.BORDER_REFLECT_101,
                p=1.0,
            ),
        ]
    )


def build_pp2020_tta_augmentations():
    try:
        import albumentations as A
    except Exception as e:
        raise RuntimeError(
            "TTA augmentations require `albumentations`."
        ) from e

    # Keep TTA light to avoid distribution shift.
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
        ]
    )


def apply_albumentations(augment, img: Image.Image) -> Image.Image:
    arr = np.array(img)
    out = augment(image=arr)["image"]
    if out.dtype != np.uint8:
        max_val = float(out.max()) if out.size else 1.0
        if max_val <= 1.0:
            out = (np.clip(out, 0.0, 1.0) * 255.0).round()
        else:
            out = np.clip(out, 0.0, 255.0)
        out = out.astype(np.uint8)
    return Image.fromarray(out)


def infer_encoder_hidden_size(encoder, config) -> int:
    for attr in ("hidden_size", "d_model"):
        val = getattr(encoder.config, attr, None)
        if isinstance(val, int):
            return val
    text_config = getattr(encoder.config, "text_config", None)
    if text_config is not None:
        val = getattr(text_config, "hidden_size", None)
        if isinstance(val, int):
            return val
    for attr in ("hidden_size", "d_model"):
        val = getattr(config, attr, None)
        if isinstance(val, int):
            return val
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        val = getattr(text_config, "hidden_size", None)
        if isinstance(val, int):
            return val
    get_embeddings = getattr(encoder, "get_input_embeddings", None)
    if callable(get_embeddings):
        emb = get_embeddings()
        if emb is not None:
            if hasattr(emb, "embedding_dim"):
                return int(emb.embedding_dim)
            if hasattr(emb, "weight") and emb.weight is not None:
                return int(emb.weight.shape[1])
    raise ValueError("Could not infer encoder hidden size for classification head.")


class EncoderClassifier(nn.Module):
    def __init__(
        self,
        model_id: str,
        num_labels: int,
        classifier_dropout: float = 0.1,
        classifier_head: str = "linear",
        classifier_hidden_dim: int = 0,
    ):
        super().__init__()
        self.backbone = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.num_labels = num_labels
        encoder = self.backbone.get_encoder()
        hidden_size = infer_encoder_hidden_size(encoder, self.backbone.config)
        self.dropout = nn.Dropout(p=classifier_dropout)
        backbone_dtype = next(self.backbone.parameters()).dtype
        head = classifier_head.lower()
        if head == "linear":
            self.classifier = nn.Linear(hidden_size, num_labels)
        elif head == "mlp":
            mlp_dim = classifier_hidden_dim if classifier_hidden_dim > 0 else hidden_size
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, mlp_dim),
                nn.GELU(),
                nn.Dropout(p=classifier_dropout),
                nn.Linear(mlp_dim, num_labels),
            )
        else:
            raise ValueError(f"Unknown classifier_head={classifier_head!r}.")
        self.classifier.to(dtype=backbone_dtype)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def _classifier_dtype(self) -> torch.dtype:
        for p in self.classifier.parameters():
            return p.dtype
        return torch.float32

    def _attention_mask_dtype(self) -> torch.dtype:
        if torch.is_autocast_enabled():
            if torch.cuda.is_available():
                get_dtype = getattr(torch, "get_autocast_dtype", None)
                if callable(get_dtype):
                    return get_dtype("cuda")
                get_dtype = getattr(torch, "get_autocast_gpu_dtype", None)
                if callable(get_dtype):
                    return get_dtype()
                if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
                    return torch.bfloat16
                return torch.float16
            get_dtype = getattr(torch, "get_autocast_dtype", None)
            if callable(get_dtype):
                return get_dtype("cpu")
            get_dtype = getattr(torch, "get_autocast_cpu_dtype", None)
            if callable(get_dtype):
                return get_dtype()
        emb = self.backbone.get_input_embeddings()
        if emb is not None and hasattr(emb, "weight") and emb.weight is not None:
            return emb.weight.dtype
        return torch.float32

    def _build_full_attention_mask(
        self, attention_mask: torch.Tensor, dtype: torch.dtype
    ) -> torch.Tensor:
        # Build an additive mask: 0 for keep, large negative for masked keys.
        mask = (1.0 - attention_mask.to(dtype)) * torch.finfo(dtype).min
        return mask[:, None, None, :].expand(-1, 1, attention_mask.shape[1], -1)

    def _pool(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> EncoderClassifierOutput:
        orig_attention_mask = attention_mask
        attn_mask_mapping = None
        if attention_mask is not None:
            mask_dtype = self._attention_mask_dtype()
            full_mask = self._build_full_attention_mask(attention_mask, mask_dtype)
        else:
            full_mask = None
        # Avoid transformer mask helpers that require torch>=2.6 by passing a dict.
        attn_mask_mapping = {"full_attention": full_mask, "sliding_attention": full_mask}

        encoder_kwargs = {}
        if input_ids is not None:
            encoder_kwargs["input_ids"] = input_ids
        encoder_kwargs["attention_mask"] = attn_mask_mapping
        if pixel_values is not None:
            encoder_kwargs["pixel_values"] = pixel_values
        if "position_ids" in kwargs and kwargs["position_ids"] is not None:
            encoder_kwargs["position_ids"] = kwargs["position_ids"]
        if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
            encoder_kwargs["inputs_embeds"] = kwargs["inputs_embeds"]

        encoder = self.backbone.get_encoder()
        enc_out = encoder(return_dict=True, **encoder_kwargs)
        pooled = self._pool(enc_out.last_hidden_state, orig_attention_mask)
        pooled = pooled.to(self._classifier_dtype())
        logits = self.classifier(self.dropout(pooled))
        loss = self.loss_fn(logits.float(), labels) if labels is not None else None
        return EncoderClassifierOutput(loss=loss, logits=logits)


def read_train_csv(train_csv: Path) -> List[Dict[str, str]]:
    with train_csv.open("r", encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
    if not rows:
        raise ValueError(f"Empty train.csv: {train_csv}")
    for col in ["image_id"] + LABEL_COLS:
        if col not in rows[0]:
            raise ValueError(f"Expected column '{col}' in {train_csv}, got columns={list(rows[0].keys())}")
    return rows


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def resolve_pp2020_image(images_dir: Path, image_id: str) -> Path:
    # In PP2020, image files are typically images/<image_id>.jpg
    cand = images_dir / f"{image_id}.jpg"
    if cand.exists():
        return cand
    # fallback: try png/jpeg variants
    for ext in [".jpeg", ".png", ".bmp", ".webp"]:
        cand2 = images_dir / f"{image_id}{ext}"
        if cand2.exists():
            return cand2
    raise FileNotFoundError(f"Could not find image for image_id={image_id} under {images_dir}")


class PP2020Dataset(Dataset):
    def __init__(
        self,
        rows: List[Dict[str, str]],
        images_dir: Path,
        processor,
        max_items: int = 0,
        augment=None,
        use_soft_labels: bool = False,
    ):
        self.rows = rows[: max_items] if max_items and max_items > 0 else rows
        self.images_dir = images_dir
        self.processor = processor
        self.prompt = "<start_of_image>"
        self.augment = augment
        self.use_soft_labels = use_soft_labels

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        r = self.rows[idx]
        img_path = resolve_pp2020_image(self.images_dir, r["image_id"])

        try:
            img = Image.open(img_path).convert("RGB")
            if self.augment is not None:
                img = apply_albumentations(self.augment, img)
        except Exception:
            # Defensive fallback to a black image (won't crash training)
            img = Image.fromarray(np.zeros((896, 896, 3), dtype=np.uint8), mode="RGB")

        enc = self.processor(text=self.prompt, images=img, return_tensors="pt")
        batch = {k: v.squeeze(0) for k, v in enc.items() if isinstance(v, torch.Tensor)}

        if self.use_soft_labels and "soft_labels" in r:
            y_vals = r["soft_labels"]
        else:
            y_vals = [float(r[c]) for c in LABEL_COLS]
        y = torch.tensor(y_vals, dtype=torch.float32)
        batch["labels"] = y
        return batch


class ImageOnlyDataset(Dataset):
    def __init__(
        self,
        rows: List[Dict[str, str]],
        images_dir: Path,
        processor,
        max_items: int = 0,
        augment=None,
    ):
        self.rows = rows[: max_items] if max_items and max_items > 0 else rows
        self.images_dir = images_dir
        self.processor = processor
        self.augment = augment
        self.prompt = "<start_of_image>"

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        r = self.rows[idx]
        img_path = resolve_pp2020_image(self.images_dir, r["image_id"])

        try:
            img = Image.open(img_path).convert("RGB")
            if self.augment is not None:
                img = apply_albumentations(self.augment, img)
        except Exception:
            img = Image.fromarray(np.zeros((896, 896, 3), dtype=np.uint8), mode="RGB")

        enc = self.processor(text=self.prompt, images=img, return_tensors="pt")
        pixel_values = enc["pixel_values"].squeeze(0)
        return pixel_values, r["image_id"]


def collate_fn(examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Token padding (prompt is tiny, but pad defensively)
    input_ids = [e["input_ids"] for e in examples]
    attn = [e.get("attention_mask", torch.ones_like(e["input_ids"])) for e in examples]
    pix = [e["pixel_values"] for e in examples]
    labels = [e["labels"] for e in examples]

    max_len = max(x.shape[-1] for x in input_ids)

    def pad_1d(x: torch.Tensor, pad_val: int = 0) -> torch.Tensor:
        if x.shape[-1] == max_len:
            return x
        out = torch.full((max_len,), pad_val, dtype=x.dtype)
        out[: x.shape[-1]] = x
        return out

    return {
        "input_ids": torch.stack([pad_1d(x, 0) for x in input_ids], dim=0),
        "attention_mask": torch.stack([pad_1d(x, 0) for x in attn], dim=0),
        "pixel_values": torch.stack(pix, dim=0),
        "labels": torch.stack(labels, dim=0),
    }


def freeze_all_but_classifier(model: nn.Module) -> List[str]:
    for p in model.parameters():
        p.requires_grad = False

    trainable: List[str] = []
    # HF convention for sequence classification heads varies; include several patterns.
    head_patterns = ["classifier", "classification_head", "score"]
    for name, p in model.named_parameters():
        if any(pat in name for pat in head_patterns):
            p.requires_grad = True
            trainable.append(name)

    # Fallback: unfreeze last Linear module found
    if not trainable:
        last_linear_prefix = None
        for name, m in model.named_modules():
            if isinstance(m, nn.Linear):
                last_linear_prefix = name
        if last_linear_prefix is not None:
            for name, p in model.named_parameters():
                if name.startswith(last_linear_prefix):
                    p.requires_grad = True
                    trainable.append(name)

    if not trainable:
        raise RuntimeError("Failed to identify classification head parameters to train.")
    return trainable


def unfreeze_all(model: nn.Module) -> List[str]:
    trainable: List[str] = []
    for name, p in model.named_parameters():
        p.requires_grad = True
        trainable.append(name)
    return trainable


@torch.no_grad()
def compute_metrics_multilabel(logits: torch.Tensor, labels: torch.Tensor) -> Dict[str, float]:
    """
    logits: (N, 4), labels: (N, 4) float {0,1}
    """
    probs = torch.sigmoid(logits.float()).cpu().numpy()
    y = labels.cpu().numpy().astype(np.int32)

    # Macro-F1 at threshold 0.5
    pred = (probs >= 0.5).astype(np.int32)
    eps = 1e-9
    tp = (pred & y).sum(axis=0).astype(np.float64)
    fp = (pred & (1 - y)).sum(axis=0).astype(np.float64)
    fn = ((1 - pred) & y).sum(axis=0).astype(np.float64)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = (2.0 * precision * recall) / (precision + recall + eps)
    f1_macro = float(np.mean(f1))

    # Exact match (all 4 labels match)
    exact = float(np.mean(np.all(pred == y, axis=1)))

    # Label-wise accuracy (fraction of correctly predicted labels)
    accuracy = float(np.mean(pred == y))

    # Mean ROC-AUC across 4 columns (Kaggle metric)
    mean_auc = float("nan")
    if roc_auc_score is not None:
        aucs = []
        for k in range(y.shape[1]):
            # roc_auc requires both classes present
            if len(np.unique(y[:, k])) < 2:
                continue
            aucs.append(roc_auc_score(y[:, k], probs[:, k]))
        mean_auc = float(np.mean(aucs)) if aucs else float("nan")

    return {
        "val/f1_macro": f1_macro,
        "val/exact_match": exact,
        "val/accuracy": accuracy,
        "val/mean_roc_auc": mean_auc,
    }


@torch.no_grad()
def evaluate(
    model,
    loader,
    device,
    tta_loader=None,
    tta_repeats: int = 0,
) -> Dict[str, float]:
    def _eval_pass(pass_loader):
        losses: List[float] = []
        all_logits: List[torch.Tensor] = []
        all_labels: List[torch.Tensor] = []

        for batch in pass_loader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            out = model(**batch)
            losses.append(float(out.loss.detach().cpu()))
            all_logits.append(out.logits.detach().float().cpu())
            all_labels.append(batch["labels"].detach().cpu())

        if not losses:
            return None, None, None
        return losses, torch.cat(all_logits, dim=0), torch.cat(all_labels, dim=0)

    model.eval()
    if tta_loader is not None and tta_repeats > 0:
        total_logits = None
        total_losses: List[float] = []
        labels = None
        for _ in range(tta_repeats):
            losses, logits, pass_labels = _eval_pass(tta_loader)
            if losses is None or logits is None or pass_labels is None:
                continue
            total_losses.extend(losses)
            if labels is None:
                labels = pass_labels
            if total_logits is None:
                total_logits = logits
            else:
                total_logits += logits

        if total_logits is None or labels is None:
            return {"val/loss": float("nan")}

        logits = total_logits / float(max(1, tta_repeats))
        metrics = {
            "val/loss": float(np.mean(total_losses)) if total_losses else float("nan"),
            "val/tta_repeats": float(tta_repeats),
        }
        metrics.update(compute_metrics_multilabel(logits, labels))
        return metrics

    losses, logits, labels = _eval_pass(loader)
    if losses is None or logits is None or labels is None:
        return {"val/loss": float("nan")}

    metrics = {"val/loss": float(np.mean(losses))}
    metrics.update(compute_metrics_multilabel(logits, labels))
    return metrics


def estimate_gradient_noise_scale(
    model,
    loader,
    device,
    steps: int,
    amp_device: str,
    amp_dtype: torch.dtype,
) -> Tuple[float, float]:
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        return float("nan"), float("nan")

    was_training = model.training
    model.train()
    data_iter = iter(loader)
    gns_values: List[float] = []

    def next_batch():
        nonlocal data_iter
        try:
            return next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter)

    def grad_vector(batch) -> torch.Tensor:
        model.zero_grad(set_to_none=True)
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast(amp_device, enabled=amp_device == "cuda", dtype=amp_dtype):
            out = model(**batch)
            if out.loss is None:
                raise RuntimeError("Gradient noise scale estimation requires labels.")
            loss = out.loss
        loss.backward()
        grads = [p.grad.detach().float().flatten() for p in params if p.grad is not None]
        if not grads:
            raise RuntimeError("No gradients collected for gradient noise scale estimation.")
        return torch.cat(grads)

    for _ in range(max(1, steps)):
        b1 = next_batch()
        b2 = next_batch()
        g1 = grad_vector(b1)
        g2 = grad_vector(b2)
        g = 0.5 * (g1 + g2)
        denom = 2.0 * (g.pow(2).sum() + 1e-12)
        num = (g1 - g2).pow(2).sum()
        batch_size = int(b1["labels"].shape[0])
        gns = (num / denom) * batch_size
        gns_values.append(float(gns))

    model.zero_grad(set_to_none=True)
    if not was_training:
        model.eval()

    return float(np.mean(gns_values)), float(np.std(gns_values))


@torch.no_grad()
def compute_image_embeddings(
    rows: List[Dict[str, str]],
    images_dir: Path,
    processor,
    model,
    device,
    batch_size: int,
    num_workers: int,
    amp_device: str,
    amp_dtype: torch.dtype,
) -> Tuple[torch.Tensor, List[str]]:
    encoder = model.backbone.get_encoder()
    if not hasattr(encoder, "get_image_features"):
        raise RuntimeError("Encoder does not expose get_image_features for duplicate checking.")

    ds = ImageOnlyDataset(rows, images_dir, processor, max_items=0, augment=None)
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    was_training = model.training
    model.eval()
    embeddings: List[torch.Tensor] = []
    image_ids: List[str] = []

    for pixel_values, ids in loader:
        pixel_values = pixel_values.to(device, non_blocking=True)
        with torch.amp.autocast(amp_device, enabled=amp_device == "cuda", dtype=amp_dtype):
            feats = encoder.get_image_features(pixel_values)
        pooled = feats.mean(dim=1)
        embeddings.append(pooled.float().cpu())
        image_ids.extend(list(ids))

    if was_training:
        model.train()

    return torch.cat(embeddings, dim=0), image_ids


def find_duplicate_pairs(
    embeddings: torch.Tensor,
    image_ids: List[str],
    threshold: float,
    max_pairs: int,
    chunk_size: int,
) -> Tuple[List[Dict[str, float]], int]:
    if embeddings.numel() == 0:
        return [], 0

    emb = F.normalize(embeddings.float(), dim=1)
    total_pairs = 0
    pairs: List[Dict[str, float]] = []
    n = emb.shape[0]

    for start in range(0, n, max(1, chunk_size)):
        end = min(start + max(1, chunk_size), n)
        sims = emb[start:end] @ emb.T
        for row in range(end - start):
            idx = start + row
            sim_row = sims[row]
            # Mask self and earlier indices so each pair is counted once.
            # This must also run for the last row, or it would match itself.
            sim_row[: idx + 1] = -1.0
            hits = (sim_row >= threshold).nonzero(as_tuple=False).squeeze(1)
            total_pairs += int(hits.numel())
            if len(pairs) >= max_pairs:
                continue
            for j in hits.tolist():
                if len(pairs) >= max_pairs:
                    break
                pairs.append(
                    {
                        "image_id_a": image_ids[idx],
                        "image_id_b": image_ids[j],
                        "cosine_sim": float(sim_row[j].item()),
                    }
                )

    return pairs, total_pairs


def scan_duplicate_pairs(
    embeddings: torch.Tensor,
    image_ids: List[str],
    threshold: float,
    max_pairs: int,
    chunk_size: int,
    build_groups: bool = False,
) -> Tuple[List[Dict[str, float]], int, Optional[Dict[int, List[int]]]]:
    if embeddings.numel() == 0:
        return [], 0, {} if build_groups else None

    emb = F.normalize(embeddings.float(), dim=1)
    total_pairs = 0
    pairs: List[Dict[str, float]] = []
    n = emb.shape[0]

    parent = list(range(n)) if build_groups else None

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            if ra < rb:
                parent[rb] = ra
            else:
                parent[ra] = rb

    for start in range(0, n, max(1, chunk_size)):
        end = min(start + max(1, chunk_size), n)
        sims = emb[start:end] @ emb.T
        for row in range(end - start):
            idx = start + row
            sim_row = sims[row]
            # Mask self and earlier indices so each pair is counted once.
            # This must also run for the last row, or it would match itself.
            sim_row[: idx + 1] = -1.0
            hits = (sim_row >= threshold).nonzero(as_tuple=False).squeeze(1)
            total_pairs += int(hits.numel())
            if build_groups and hits.numel() > 0:
                for j in hits.tolist():
                    union(idx, j)
            if len(pairs) < max_pairs:
                for j in hits.tolist():
                    if len(pairs) >= max_pairs:
                        break
                    pairs.append(
                        {
                            "image_id_a": image_ids[idx],
                            "image_id_b": image_ids[j],
                            "cosine_sim": float(sim_row[j].item()),
                        }
                    )

    groups = None
    if build_groups:
        groups = {}
        for i in range(n):
            root = find(i)
            groups.setdefault(root, []).append(i)

    return pairs, total_pairs, groups


def build_kfold_splits(n_items: int, kfolds: int, seed: int) -> List[np.ndarray]:
    """Shuffle indices 0..n_items-1 with a fixed seed and split them into kfolds validation folds."""
    if kfolds < 2:
        raise ValueError("kfolds must be >= 2 to build folds.")
    if kfolds > n_items:
        raise ValueError(f"kfolds={kfolds} is greater than dataset size={n_items}.")
    rng = np.random.RandomState(seed)
    idxs = np.arange(n_items)
    rng.shuffle(idxs)
    return [fold for fold in np.array_split(idxs, kfolds) if len(fold) > 0]


def _get_int_attr(obj, names: List[str]) -> Optional[int]:
    for name in names:
        val = getattr(obj, name, None)
        if isinstance(val, int) and val > 0:
            return val
    return None


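class ImageTokenProcessor:
    """Image-only stand-in for AutoProcessor.

    Runs the image processor and fills input_ids with mm_tokens_per_image copies
    of the image placeholder token, so no tokenizer is needed.
    """
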
    def __init__(self, image_processor, image_token_id: int, mm_tokens_per_image: int):
        self.image_processor = image_processor
        self.image_token_id = int(image_token_id)
        self.mm_tokens_per_image = int(mm_tokens_per_image)

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        if images is None:
            raise ValueError("images must be provided to the processor.")
        enc = self.image_processor(images=images, return_tensors=return_tensors)
        pixel_values = enc.get("pixel_values")
        if pixel_values is None:
            raise ValueError("image processor did not return pixel_values.")
        batch_size = int(pixel_values.shape[0])
        input_ids = torch.full(
            (batch_size, self.mm_tokens_per_image),
            self.image_token_id,
            dtype=torch.long,
        )
        attention_mask = torch.ones((batch_size, self.mm_tokens_per_image), dtype=torch.long)
        enc["input_ids"] = input_ids
        enc["attention_mask"] = attention_mask
        return enc


def build_fallback_processor(model_id: str) -> ImageTokenProcessor:
    config = AutoConfig.from_pretrained(model_id)
    encoder_cfg = getattr(config, "encoder", None) or config

    image_token_id = _get_int_attr(encoder_cfg, ["image_token_id", "image_token_index"])
    mm_tokens_per_image = _get_int_attr(encoder_cfg, ["mm_tokens_per_image"])

    if image_token_id is None:
        image_token_id = _get_int_attr(config, ["image_token_id", "image_token_index"])
    if mm_tokens_per_image is None:
        mm_tokens_per_image = _get_int_attr(config, ["mm_tokens_per_image"])

    if image_token_id is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
            for tok in ["<image>", "<start_of_image>"]:
                tok_id = tokenizer.convert_tokens_to_ids(tok)
                if isinstance(tok_id, int) and tok_id != tokenizer.unk_token_id:
                    image_token_id = tok_id
                    break
        except Exception:
            image_token_id = None

    if image_token_id is None or mm_tokens_per_image is None:
        raise RuntimeError(
            "Could not resolve image_token_id/mm_tokens_per_image for fallback processor."
        )

    image_processor = AutoImageProcessor.from_pretrained(model_id)
    print(
        "[processor] using image-only fallback processor; text tokens will be ignored.",
        flush=True,
    )
    return ImageTokenProcessor(image_processor, image_token_id, mm_tokens_per_image)


def load_processor(model_id: str, use_fast: bool):
    if use_fast:
        try:
            return AutoProcessor.from_pretrained(model_id, use_fast=True)
        except Exception as exc:
            print(
                f"[processor] fast processor failed ({exc.__class__.__name__}), falling back to slow.",
                flush=True,
            )

    try:
        return AutoProcessor.from_pretrained(model_id, use_fast=False)
    except Exception as exc:
        print(
            f"[processor] AutoProcessor failed ({exc.__class__.__name__}); using fallback.",
            flush=True,
        )
        return build_fallback_processor(model_id)


def train_simple(
    model,
    train_loader,
    args,
    device,
    amp_device: str,
    amp_dtype: torch.dtype,
    use_grad_scaler: bool,
) -> None:
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
    steps_per_epoch = math.ceil(len(train_loader) / max(1, args.grad_accum))
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(total_steps * args.warmup_ratio)
    sched = get_cosine_schedule_with_warmup(
        optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    scaler = torch.amp.GradScaler(amp_device, enabled=use_grad_scaler)

    for _ in range(1, args.epochs + 1):
        model.train()
        optim.zero_grad(set_to_none=True)
        for step, batch in enumerate(train_loader, start=1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            with torch.amp.autocast(
                amp_device,
                enabled=amp_device == "cuda",
                dtype=amp_dtype,
            ):
                out = model(**batch)
                loss = out.loss / max(1, args.grad_accum)
            if use_grad_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
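            # Step on every grad_accum-th batch and on the last batch of the epoch,
            # so a trailing partial group is not dropped and the step count matches
            # the ceil() used for the schedule above.
            if step % max(1, args.grad_accum) == 0 or step == len(train_loader):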
                if use_grad_scaler:
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()
                optim.zero_grad(set_to_none=True)
                sched.step()


@torch.no_grad()
def predict_probs(
    model,
    loader,
    device,
    amp_device: str,
    amp_dtype: torch.dtype,
) -> torch.Tensor:
    model.eval()
    probs: List[torch.Tensor] = []
    for batch in loader:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast(
            amp_device,
            enabled=amp_device == "cuda",
            dtype=amp_dtype,
        ):
            out = model(**batch)
            probs.append(torch.sigmoid(out.logits.float()).cpu())
    if not probs:
        return torch.empty((0, len(LABEL_COLS)), dtype=torch.float32)
    return torch.cat(probs, dim=0)


def build_oof_soft_labels(
    rows: List[Dict[str, str]],
    images_dir: Path,
    processor,
    args,
    device,
    amp_device: str,
    amp_dtype: torch.dtype,
    use_grad_scaler: bool,
    train_augment,
    log_cli,
) -> np.ndarray:
    """Train a fresh model per fold and return out-of-fold sigmoid probabilities for every row."""
    folds = build_kfold_splits(len(rows), args.kfolds, args.seed)
    oof_probs = np.zeros((len(rows), len(LABEL_COLS)), dtype=np.float32)
    for fold_idx, val_idx in enumerate(folds, start=1):
        train_idx = np.setdiff1d(np.arange(len(rows)), val_idx)
        fold_train_rows = [rows[i] for i in train_idx]
        fold_val_rows = [rows[i] for i in val_idx]

        fold_model = EncoderClassifier(
            model_id=args.model_id,
            num_labels=len(LABEL_COLS),
            classifier_dropout=args.classifier_dropout,
            classifier_head=args.classifier_head,
            classifier_hidden_dim=args.classifier_mlp_dim,
        )
        if args.train_full_model:
            unfreeze_all(fold_model)
        else:
            freeze_all_but_classifier(fold_model)
        fold_model.to(device)

        fold_train_ds = PP2020Dataset(
            fold_train_rows,
            images_dir,
            processor,
            max_items=0,
            augment=train_augment,
            use_soft_labels=False,
        )
        fold_val_ds = PP2020Dataset(
            fold_val_rows,
            images_dir,
            processor,
            max_items=0,
            augment=None,
            use_soft_labels=False,
        )

        fold_train_loader = DataLoader(
            fold_train_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=collate_fn,
            drop_last=False,
        )
        fold_val_loader = DataLoader(
            fold_val_ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=collate_fn,
            drop_last=False,
        )

        log_cli(f"[kfold] fold={fold_idx}/{len(folds)} train={len(fold_train_rows)} val={len(fold_val_rows)}")
        train_simple(
            fold_model,
            fold_train_loader,
            args,
            device,
            amp_device,
            amp_dtype,
            use_grad_scaler,
        )
        fold_probs = predict_probs(
            fold_model,
            fold_val_loader,
            device,
            amp_device,
            amp_dtype,
        )
        oof_probs[val_idx] = fold_probs.numpy()
        del fold_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return oof_probs


def main() -> None:
    ap = argparse.ArgumentParser()
    ap.add_argument("--data-dir", type=str, default="./pp2020")
    ap.add_argument("--model-id", type=str, default="google/t5gemma-2-270m-270m")
    ap.add_argument("--classifier-head", type=str, choices=["linear", "mlp"], default="linear")
    ap.add_argument("--classifier-mlp-dim", type=int, default=0)
    ap.add_argument("--classifier-dropout", type=float, default=0.1)
    ap.add_argument("--epochs", type=int, default=5)
    ap.add_argument("--batch-size", type=int, default=8)
    ap.add_argument("--num-workers", type=int, default=4)
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--weight-decay", type=float, default=0.0)
    ap.add_argument("--warmup-ratio", type=float, default=0.05)
    ap.add_argument("--grad-accum", type=int, default=1)
    ap.add_argument("--log-every", type=int, default=10)
    ap.add_argument("--eval-every", type=int, default=25)
    ap.add_argument("--augment", action="store_true")
    ap.add_argument("--tta-repeats", type=int, default=0)
    ap.add_argument("--estimate-gns", action="store_true")
    ap.add_argument("--gns-steps", type=int, default=5)
    ap.add_argument("--gns-batch-size", type=int, default=8)
    ap.add_argument("--check-duplicates", action="store_true")
    ap.add_argument("--remove-duplicates", action="store_true")
    ap.add_argument("--train-full-model", action="store_true")
    ap.add_argument("--kfolds", type=int, default=0)
    ap.add_argument("--label-mix-alpha", type=float, default=0.0)
    ap.add_argument("--dedupe-threshold", type=float, default=0.995)
    ap.add_argument("--dedupe-batch-size", type=int, default=124)
    ap.add_argument("--dedupe-max-pairs", type=int, default=20)
    ap.add_argument("--dedupe-chunk-size", type=int, default=256)
    ap.add_argument("--dedupe-output", type=str, default="duplicate_report.json")
    ap.add_argument("--seed", type=int, default=42)

    ap.add_argument("--val-split", type=float, default=1.0 / 3.0)
    ap.add_argument("--max-items", type=int, default=0)

    ap.add_argument("--project", type=str, default="pp2020-t5gemma2")
    ap.add_argument("--run-name", type=str, default="")
    ap.add_argument("--wandb-mode", type=str, choices=["online", "offline", "disabled"], default="online")
    ap.add_argument("--wandb-init-timeout", type=int, default=180)
    ap.add_argument("--use-fast-processor", action=argparse.BooleanOptionalAction, default=True)
    args = ap.parse_args()

    seed_everything(args.seed)

    # W&B: the CLI choices map directly onto WANDB_MODE values.
    os.environ["WANDB_MODE"] = args.wandb_mode

    wandb.init(
        project=args.project,
        name=args.run_name or None,  # empty string -> let W&B pick a name
        settings=wandb.Settings(init_timeout=args.wandb_init_timeout),
        config={
            "dataset": "pp2020",
            "model_id": args.model_id,
            "classifier_head": args.classifier_head,
            "classifier_mlp_dim": args.classifier_mlp_dim,
            "classifier_dropout": args.classifier_dropout,
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "warmup_ratio": args.warmup_ratio,
            "grad_accum": args.grad_accum,
            "log_every": args.log_every,
            "eval_every": args.eval_every,
            "augment": args.augment,
            "tta_repeats": args.tta_repeats,
            "estimate_gns": args.estimate_gns,
            "gns_steps": args.gns_steps,
            "gns_batch_size": args.gns_batch_size,
            "check_duplicates": args.check_duplicates,
            "remove_duplicates": args.remove_duplicates,
            "train_full_model": args.train_full_model,
            "kfolds": args.kfolds,
            "label_mix_alpha": args.label_mix_alpha,
            "dedupe_threshold": args.dedupe_threshold,
            "dedupe_batch_size": args.dedupe_batch_size,
            "dedupe_max_pairs": args.dedupe_max_pairs,
            "dedupe_chunk_size": args.dedupe_chunk_size,
            "dedupe_output": args.dedupe_output,
            "seed": args.seed,
            "val_split": args.val_split,
            "max_items": args.max_items,
            "labels": LABEL_COLS,
            "wandb_init_timeout": args.wandb_init_timeout,
            "use_fast_processor": args.use_fast_processor,
        },
    )

    # Data directory should contain train.csv and images/
    root = Path(args.data_dir).resolve()

    train_csv = root / "train.csv"
    images_dir = root / "images"
    if not train_csv.exists():
        raise FileNotFoundError(f"Missing {train_csv}")
    if not images_dir.exists():
        raise FileNotFoundError(f"Missing {images_dir}")

    rows = read_train_csv(train_csv)

    train_augment = None
    if args.augment:
        train_augment = build_pp2020_augmentations()

    # Model + processor
    processor = load_processor(args.model_id, args.use_fast_processor)
    model = EncoderClassifier(
        model_id=args.model_id,
        num_labels=len(LABEL_COLS),
        classifier_dropout=args.classifier_dropout,
        classifier_head=args.classifier_head,
        classifier_hidden_dim=args.classifier_mlp_dim,
    )

    if args.train_full_model:
        trainable_names = unfreeze_all(model)
    else:
        trainable_names = freeze_all_but_classifier(model)
    wandb.config.update({"trainable_param_names": trainable_names}, allow_val_change=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    amp_device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            amp_dtype = torch.bfloat16
        else:
            amp_dtype = torch.float16
    else:
        amp_dtype = torch.float32
    use_grad_scaler = torch.cuda.is_available() and amp_dtype == torch.float16
    scaler = torch.amp.GradScaler(amp_device, enabled=use_grad_scaler)

    def log_cli(msg: str) -> None:
        print(msg, flush=True)

    if args.check_duplicates or args.remove_duplicates:
        embeddings, image_ids = compute_image_embeddings(
            rows,
            images_dir,
            processor,
            model,
            device,
            args.dedupe_batch_size,
            args.num_workers,
            amp_device,
            amp_dtype,
        )
        max_pairs = args.dedupe_max_pairs if args.check_duplicates else 0
        pairs, total_pairs, groups = scan_duplicate_pairs(
            embeddings,
            image_ids,
            args.dedupe_threshold,
            max_pairs,
            args.dedupe_chunk_size,
            build_groups=args.remove_duplicates,
        )
        report = {
            "threshold": args.dedupe_threshold,
            "checked_images": len(image_ids),
            "total_pairs": total_pairs,
            "pairs": pairs,
        }
        if args.remove_duplicates:
            keep = set()
            removed_ids: List[str] = []
            if groups:
                for group in groups.values():
                    group_sorted = sorted(group)
                    keep_idx = group_sorted[0]
                    keep.add(keep_idx)
                    for idx in group_sorted[1:]:
                        removed_ids.append(image_ids[idx])
            if removed_ids:
                rows = [rows[i] for i in range(len(rows)) if i in keep]
            report["removed_count"] = len(removed_ids)
            report["kept_count"] = len(rows)
            report["removed_image_ids"] = removed_ids
            log_cli(f"[dedupe] removed={len(removed_ids)} kept={len(rows)}")
            wandb.log({"dedupe/removed_count": len(removed_ids)}, step=0)
        Path(args.dedupe_output).write_text(json.dumps(report, indent=2), encoding="utf-8")
        wandb.log(
            {
                "dedupe/checked_images": len(image_ids),
                "dedupe/total_pairs": total_pairs,
                "dedupe/threshold": args.dedupe_threshold,
            },
            step=0,
        )
        log_cli(
            f"[dedupe] checked={len(image_ids)} pairs={total_pairs} report={args.dedupe_output}"
        )
        for pair in pairs:
            log_cli(
                f"[dedupe] {pair['image_id_a']} <-> {pair['image_id_b']} sim={pair['cosine_sim']:.4f}"
            )

    # Random holdout split (no group ids available in PP2020 package)
    idxs = np.arange(len(rows))
    np.random.shuffle(idxs)
    n_val = max(1, int(len(rows) * args.val_split))
    val_set = set(idxs[:n_val].tolist())
    train_rows = [rows[i] for i in range(len(rows)) if i not in val_set]
    val_rows = [rows[i] for i in range(len(rows)) if i in val_set]

    if args.max_items and args.max_items > 0:
        train_rows = train_rows[: args.max_items]
        val_rows = val_rows[: max(1, int(args.max_items * args.val_split))]

    if args.label_mix_alpha < 0.0 or args.label_mix_alpha > 1.0:
        raise ValueError("--label-mix-alpha must be in [0, 1].")
    if args.tta_repeats < 0:
        raise ValueError("--tta-repeats must be >= 0.")
    if args.classifier_mlp_dim < 0:
        raise ValueError("--classifier-mlp-dim must be >= 0.")
    if args.classifier_dropout < 0.0 or args.classifier_dropout > 1.0:
        raise ValueError("--classifier-dropout must be in [0, 1].")

    use_soft_labels = args.kfolds >= 2 and args.label_mix_alpha > 0.0
    if args.label_mix_alpha > 0.0 and args.kfolds < 2:
        raise ValueError("--label-mix-alpha requires --kfolds >= 2 for OOF predictions.")

    if use_soft_labels:
        oof_probs = build_oof_soft_labels(
            train_rows,
            images_dir,
            processor,
            args,
            device,
            amp_device,
            amp_dtype,
            use_grad_scaler,
            train_augment,
            log_cli,
        )
        for idx, row in enumerate(train_rows):
            y = np.array([float(row[c]) for c in LABEL_COLS], dtype=np.float32)
            soft = (1.0 - args.label_mix_alpha) * y + args.label_mix_alpha * oof_probs[idx]
            row["soft_labels"] = soft.tolist()
        log_cli(
            f"[kfold] mixed labels alpha={args.label_mix_alpha:.2f} folds={args.kfolds}"
        )

    # Data loaders
    train_ds = PP2020Dataset(
        train_rows,
        images_dir,
        processor,
        max_items=0,
        augment=train_augment,
        use_soft_labels=use_soft_labels,
    )
    val_ds = PP2020Dataset(val_rows, images_dir, processor, max_items=0, augment=None)

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        drop_last=False,
    )
    val_tta_loader = None
    if args.tta_repeats > 0:
        val_tta_ds = PP2020Dataset(
            val_rows,
            images_dir,
            processor,
            max_items=0,
            augment=build_pp2020_tta_augmentations(),
        )
        val_tta_loader = DataLoader(
            val_tta_ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=collate_fn,
            drop_last=False,
        )

    # Optim + schedule over the trainable params (head-only unless --train-full-model)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)

    steps_per_epoch = math.ceil(len(train_loader) / max(1, args.grad_accum))
    total_steps = steps_per_epoch * args.epochs
    warmup_steps = int(total_steps * args.warmup_ratio)
    sched = get_cosine_schedule_with_warmup(
        optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    if args.augment:
        log_cli("[augment] enabled=pp2020")
    if args.tta_repeats > 0:
        log_cli(f"[tta] enabled repeats={args.tta_repeats}")

    if args.estimate_gns:
        gns_loader = DataLoader(
            train_ds,
            batch_size=args.gns_batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=torch.cuda.is_available(),
            collate_fn=collate_fn,
            drop_last=True,
        )
        gns_mean, gns_std = estimate_gradient_noise_scale(
            model,
            gns_loader,
            device,
            args.gns_steps,
            amp_device,
            amp_dtype,
        )
        # isfinite also guards against inf, which would make round() raise.
        gns_opt = int(max(1.0, round(gns_mean))) if math.isfinite(gns_mean) else 0
        wandb.log(
            {
                "gns/scale": gns_mean,
                "gns/std": gns_std,
                "gns/batch_size": args.gns_batch_size,
                "gns/steps": args.gns_steps,
                "gns/optimal_batch_size": gns_opt,
            },
            step=0,
        )
        log_cli(f"[gns] scale={gns_mean:.4f} std={gns_std:.4f} opt_batch~{gns_opt}")

    # Train
    global_step = 0
    best_auc = -1.0
    best_path = Path("best_pp2020_t5gemma2_head.pt").resolve()
    last_eval_step = -1

    for epoch in range(1, args.epochs + 1):
        model.train()
        optim.zero_grad(set_to_none=True)
        running = 0.0

        for step, batch in enumerate(train_loader, start=1):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

            with torch.amp.autocast(
                amp_device,
                enabled=torch.cuda.is_available(),
                dtype=amp_dtype,
            ):
                out = model(**batch)
                loss = out.loss / max(1, args.grad_accum)

            if use_grad_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            running += float(loss.detach().cpu())

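            # Step on every grad_accum-th batch and on the last batch of the epoch,
            # so a trailing partial group is not dropped and the step count matches
            # the ceil() used for the schedule.
            if step % max(1, args.grad_accum) == 0 or step == len(train_loader):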
                if use_grad_scaler:
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()
                optim.zero_grad(set_to_none=True)
                sched.step()

                global_step += 1
                lr = sched.get_last_lr()[0]
                wandb.log(
                    {
                        "train/loss": running,
                        "train/lr": lr,
                        "train/epoch": epoch,
                        "train/step": global_step,
                    },
                    step=global_step,
                )
                if args.log_every > 0 and global_step % args.log_every == 0:
                    log_cli(
                        f"[train] step={global_step} epoch={epoch} loss={running:.4f} lr={lr:.2e}"
                    )
                running = 0.0

                if args.eval_every > 0 and global_step % args.eval_every == 0:
                    metrics = evaluate(
                        model,
                        val_loader,
                        device,
                        tta_loader=val_tta_loader,
                        tta_repeats=args.tta_repeats,
                    )
                    metrics["val/epoch"] = epoch
                    metrics["val/step"] = global_step
                    wandb.log(metrics, step=global_step)
                    log_cli(
                        "[eval] "
                        f"step={global_step} loss={metrics.get('val/loss', float('nan')):.4f} "
                        f"f1={metrics.get('val/f1_macro', float('nan')):.4f} "
                        f"acc={metrics.get('val/accuracy', float('nan')):.4f} "
                        f"auc={metrics.get('val/mean_roc_auc', float('nan')):.4f}"
                    )

                    sel_auc = metrics.get("val/mean_roc_auc", float("nan"))
                    if not math.isnan(sel_auc) and sel_auc > best_auc:
                        best_auc = float(sel_auc)
                        torch.save(
                            {
                                "model_id": args.model_id,
                                "state_dict": model.state_dict(),
                                "labels": LABEL_COLS,
                                "processor_id": args.model_id,
                                "metrics": metrics,
                            },
                            best_path,
                        )
                        wandb.log({"val/best_mean_roc_auc": best_auc}, step=global_step)
                    model.train()
                    last_eval_step = global_step

        # Eval at epoch end if no recent eval ran on the last step.
        if last_eval_step != global_step:
            metrics = evaluate(
                model,
                val_loader,
                device,
                tta_loader=val_tta_loader,
                tta_repeats=args.tta_repeats,
            )
            metrics["val/epoch"] = epoch
            metrics["val/step"] = global_step
            wandb.log(metrics, step=global_step)
            log_cli(
                "[eval] "
                f"step={global_step} loss={metrics.get('val/loss', float('nan')):.4f} "
                f"f1={metrics.get('val/f1_macro', float('nan')):.4f} "
                f"acc={metrics.get('val/accuracy', float('nan')):.4f} "
                f"auc={metrics.get('val/mean_roc_auc', float('nan')):.4f}"
            )

            sel_auc = metrics.get("val/mean_roc_auc", float("nan"))
            if not math.isnan(sel_auc) and sel_auc > best_auc:
                best_auc = float(sel_auc)
                torch.save(
                    {
                        "model_id": args.model_id,
                        "state_dict": model.state_dict(),
                        "labels": LABEL_COLS,
                        "processor_id": args.model_id,
                        "metrics": metrics,
                    },
                    best_path,
                )
                wandb.log({"val/best_mean_roc_auc": best_auc}, step=global_step)
            last_eval_step = global_step

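    # Log artifacts (skipped if no checkpoint was written, e.g. AUC was NaN on every eval)
    if best_path.exists():
        art = wandb.Artifact(name=f"{args.run_name or args.project}-best", type="model")
        art.add_file(str(best_path))
        wandb.log_artifact(art)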

    manifest = {
        "dataset": "pp2020",
        "data_root": str(root),
        "train_csv": str(train_csv),
        "images_dir": str(images_dir),
        "model_id": args.model_id,
        "classifier_head": args.classifier_head,
        "classifier_mlp_dim": args.classifier_mlp_dim,
        "classifier_dropout": args.classifier_dropout,
        "labels": LABEL_COLS,
        "best_mean_roc_auc": best_auc,
    }
    Path("run_manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    wandb.save("run_manifest.json")


if __name__ == "__main__":
    main()
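
The duplicate removal in `scan_duplicate_pairs` is easier to follow on a toy input. What follows is a minimal sketch of the same idea, not the code path the script runs: it uses plain Python lists instead of torch tensors, and `cosine`, `group_near_duplicates` and the 2-D "embeddings" are made up for illustration. It thresholds cosine similarity over the upper triangle only, unions matching pairs with a small union-find, and keeps the lowest index of each group.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def group_near_duplicates(
    vectors: List[Sequence[float]], threshold: float
) -> Dict[int, List[int]]:
    """Union every pair (i, j), i < j, whose cosine similarity is >= threshold."""
    parent = list(range(len(vectors)))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for i in range(len(vectors)):
        # Upper triangle only: no self-pairs and no pair visited twice.
        for j in range(i + 1, len(vectors)):
            if cosine(vectors[i], vectors[j]) >= threshold:
                ri, rj = find(i), find(j)
                if ri != rj:
                    # The smaller index becomes the root, as in the script.
                    parent[max(ri, rj)] = min(ri, rj)

    groups: Dict[int, List[int]] = {}
    for i in range(len(vectors)):
        groups.setdefault(find(i), []).append(i)
    return groups


# Hypothetical embeddings: 0 and 2 point almost the same way, 1 and 3 stand alone.
toy = [(1.0, 0.0), (0.0, 1.0), (0.999, 0.01), (-1.0, 0.0)]
groups = group_near_duplicates(toy, threshold=0.995)
keep = sorted(min(members) for members in groups.values())
print(groups)
print(keep)
```

On this toy input, items 0 and 2 land in one group, so the keep list is `[0, 1, 3]`. Because union-find is transitive, a chain of near-duplicates (A close to B, B close to C) collapses into one group even if A and C fall below the threshold, which is worth keeping in mind when picking `--dedupe-threshold`.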

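The `--kfolds` / `--label-mix-alpha` path blends each hard label with an out-of-fold (OOF) prediction: `soft = (1 - alpha) * y + alpha * oof`. Here is a rough sketch of that flow under loud assumptions: the "model" is a stand-in that predicts the mean label of its training fold, the folds come from the standard library's `random` rather than `np.random.RandomState`, and `kfold_indices` and `mix_labels` are hypothetical helper names.

```python
import random
from typing import List


def kfold_indices(n_items: int, kfolds: int, seed: int) -> List[List[int]]:
    """Shuffle 0..n_items-1 and deal the indices into kfolds near-equal folds."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    return [idxs[f::kfolds] for f in range(kfolds)]


def mix_labels(hard: List[float], oof: List[float], alpha: float) -> List[float]:
    """Blend hard labels with out-of-fold predictions, element by element."""
    return [(1.0 - alpha) * y + alpha * p for y, p in zip(hard, oof)]


# Toy data: six samples with a single binary label.
labels = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0]
folds = kfold_indices(len(labels), kfolds=3, seed=42)

oof: List[float] = [0.0] * len(labels)
for val_idx in folds:
    held_out = set(val_idx)
    train_labels = [labels[i] for i in range(len(labels)) if i not in held_out]
    # Stand-in "model": predict the mean training label for every held-out sample.
    pred = sum(train_labels) / len(train_labels)
    for i in val_idx:
        oof[i] = pred  # each sample is scored by a model that never saw it

soft = mix_labels(labels, oof, alpha=0.25)
print(soft)
```

With `alpha=0.25` a positive label can never drop below 0.75 and a negative can never rise above 0.25, so the hard label still dominates; the OOF term only softens targets the held-out models disagree with.
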
Here are the requirements:

--extra-index-url https://download.pytorch.org/whl/cu118

torch==2.5.1+cu118
torchvision==0.20.1+cu118

# T5Gemma2 support via transformers git (matches current setup)
transformers @ git+https://github.com/huggingface/transformers.git@ad7f4d0103599ff098bb33c11b9c1a73d97262fd
huggingface-hub==1.3.1

sentencepiece
protobuf
pillow
numpy
wandb
kaggle
scikit-learn
albumentations
opencv-python
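
Since the setup depends on a few exact pins, a quick check after installing can save a confusing debugging session. This is a small sketch using only the standard library; `find_pin_mismatches` is a hypothetical helper, and it assumes the installed package metadata reports the same version strings as the pins (including the `+cu118` suffix).

```python
from importlib import metadata
from typing import Callable, Dict, List


def find_pin_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> List[str]:
    """Return one message per package whose installed version differs from its pin."""
    problems: List[str] = []
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed (wanted {wanted})")
            continue
        if found != wanted:
            problems.append(f"{name}: found {found}, wanted {wanted}")
    return problems


# Pins copied from the requirements list above.
PINS = {
    "torch": "2.5.1+cu118",
    "torchvision": "0.20.1+cu118",
    "huggingface-hub": "1.3.1",
}

if __name__ == "__main__":
    for line in find_pin_mismatches(PINS) or ["all pinned packages match"]:
        print(line)
```

The `transformers` git commit is not covered here, since a source install does not expose the commit hash as a plain version string.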