admin管理员组

文章数量:1429850

I'm trying to train my EfficientdetD0 model, but the loss function values are in the billions.

My code:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CocoDetection
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import transforms as T
from pytorch_lightning.loggers import CSVLogger

class EfficientDetDataModule(pl.LightningDataModule):
    def __init__(self, train_dir, val_dir, train_ann, val_ann, batch_size=2, num_workers=4):
        super().__init__()
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.train_ann = train_ann
        self.val_ann = val_ann
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.coco_train = CocoDetection(
            root=self.train_dir, 
            annFile=self.train_ann, 
            transform=self.train_transforms()
        )
        self.coco_val = CocoDetection(
            root=self.val_dir, 
            annFile=self.val_ann, 
            transform=self.val_transforms()
        )

    def train_transforms(self):
        return T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def val_transforms(self):
        return T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def collate_fn(self, batch):
        images, targets = zip(*batch)
        images = torch.stack(images)

        bboxes = []
        classes = []
        img_scales = []
        img_sizes = []

        for target_per_image in targets:
            bboxes_per_image = []
            classes_per_image = []

            for obj in target_per_image:
                bbox_tensor = torch.tensor(obj['bbox']).float()
                class_tensor = torch.tensor([obj['category_id']]).long()

                x_min, y_min, width, height = bbox_tensor
                x_max = x_min + width
                y_max = y_min + height

                if x_max <= x_min or y_max <= y_min:
                    print(f"Skipping invalid bounding box: {bbox_tensor}")
                    continue

                corrected_bbox = torch.tensor([x_min, y_min, x_max, y_max]).float()
                bboxes_per_image.append(corrected_bbox.unsqueeze(0))
                classes_per_image.append(class_tensor)

            if bboxes_per_image:
                bboxes.append(torch.cat(bboxes_per_image, dim=0))
                classes.append(torch.cat(classes_per_image, dim=0))
            else:
                bboxes.append(torch.empty((0, 4), dtype=torch.float32))
                classes.append(torch.empty((0,), dtype=torch.long))

            img_scales.append(1.0)
            img_sizes.append([images.shape[2], images.shape[3]])

        batch_targets = {
            'bbox': bboxes,
            'cls': classes,
            'img_scale': torch.tensor(img_scales, device=images.device),
            'img_size': torch.tensor(img_sizes, dtype=torch.float32, device=images.device),
        }

        return images, batch_targets

    def train_dataloader(self):
        return DataLoader(
            self.coco_train, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers, 
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        return DataLoader(
            self.coco_val, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers, 
            collate_fn=self.collate_fn
        )

class EfficientDetModel(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        self.save_hyperparameters()
        config = get_efficientdet_config('tf_efficientdet_d0')
        self.model = EfficientDet(config, pretrained_backbone=True)
        self.model.class_net.num_classes = num_classes
        self.model = DetBenchTrain(self.model, config)

    def forward(self, images, targets=None):
        return self.model(images, targets)

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        targets['img_scale'] = targets['img_scale'].to(images.device)
        targets['img_size'] = targets['img_size'].to(images.device)

        output = self(images, targets)
        loss = output['loss']
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        images, targets = batch
        targets['img_scale'] = targets['img_scale'].to(images.device)
        targets['img_size'] = targets['img_size'].to(images.device)

        output = self(images, targets)

        loss = output['loss']
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Images: {images}")
            print(f"Targets: {targets}")
            print(f"Model output: {output}")
            raise ValueError(f"Loss contains NaN or Inf values: {loss}")

        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        return optimizer

if __name__ == '__main__':
    train_dir = 'efficientdet/train2017'
    val_dir = 'efficientdet/valid2017'
    train_ann = 'efficientdet/annotations/instances_train2017.json'
    val_ann = 'efficientdet/annotations/instances_val2017.json'
    batch_size = 4
    num_classes = 14

    data_module = EfficientDetDataModule(train_dir, val_dir, train_ann, val_ann, batch_size=batch_size)
    data_module.setup()

    model = EfficientDetModel(num_classes=num_classes)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints",
        filename="efficientdet-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=50,
        devices=1,
        accelerator="gpu",
        callbacks=[checkpoint_callback],
        gradient_clip_val=1.0,
        logger=CSVLogger("logs", name="efficientdet"),
    )
    trainer.fit(model, data_module)

    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        trained_model = EfficientDetModel.load_from_checkpoint(
            checkpoint_path=best_model_path,
            num_classes=num_classes,
            map_location="cuda",
        )

        print(f"Model has been loaded from {best_model_path}")

Loss function values:

epoch,step,train_loss
0,49,5900649472.0
0,99,6577373184.0
0,149,7079398400.0
0,199,6111747072.0
0,249,6603147776.0
0,299,4403147264.0
0,349,6613146624.0
0,399,6705645568.0
0,449,6798145536.0
0,499,4889868800.0

I tried changing lr (1e-4, 1e-5, 1e-6), I tried changing batchsize (2, 4, 8). I also checked my dataset for incorrect bbox sizes (outside the image and negative values). I also tried changing the input image size (256, 512, 640). None of this helped to find the problem and I am asking for your help.

Here's a link to my dataset:

本文标签: pythonEfficientdetD0 The huge values of the loss functionStack Overflow