#13619 Classification model export to TFLite failed

Issue details

Opened 4 months ago by jshh0401 (author)
Labels: enhancement, classify, exports
Assignee: none

Search before asking

  • I have searched the YOLOv5 issues and found no similar feature requests.

Description

The classification model is built in an unusual way. Instead of being created directly from a YAML config, `ClassificationModel` takes a detection model, truncates it at the `cutoff` index to drop the Detect head, and replaces the last remaining layer with a `Classify` head. As a result, exporting a classification model to TFLite fails. I think this change is not difficult, but it would be very useful: it would let users reuse YOLOv5 code modules and backbones for image classification tasks.
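To make the "slice and replace" step concrete, here is a toy illustration in plain Python (no PyTorch) that mirrors the slicing in `_from_detection_model`, with each layer represented only by its class name. It assumes the standard `yolov5s.yaml` layout (backbone = layers 0 to 9, ending in `SPPF`; head = layers 10 to 24, ending in `Detect`); the `classification_layers` helper and the `YOLOV5S_LAYERS` list are hypothetical and exist only for this sketch. The point it shows is that with `cutoff=10` the last backbone layer (`SPPF`) is *replaced* by `Classify`, not followed by it.

```python
# Toy illustration only: layers are class-name strings, not real modules.
# Assumes the standard yolov5s.yaml layout (backbone 0-9, head 10-24).
YOLOV5S_LAYERS = [
    "Conv", "Conv", "C3", "Conv", "C3", "Conv", "C3", "Conv", "C3", "SPPF",  # backbone 0-9
    "Conv", "Upsample", "Concat", "C3", "Conv", "Upsample", "Concat", "C3",  # head 10-17
    "Conv", "Concat", "C3", "Conv", "Concat", "C3", "Detect",                # head 18-24
]


def classification_layers(layers, cutoff=10):
    """Hypothetical helper mimicking ClassificationModel._from_detection_model on names."""
    kept = list(layers[:cutoff])  # like model.model[:cutoff]: keep layers 0..cutoff-1
    replaced = kept[-1]           # like m = model.model[-1]: the last kept layer
    kept[-1] = "Classify"         # like model.model[-1] = c: replace it with the head
    return kept, replaced


layers, replaced = classification_layers(YOLOV5S_LAYERS)
print(layers)    # the 9 backbone layers before SPPF, then "Classify"
print(replaced)  # SPPF
```

Because the resulting module list is produced by surgery on a detection model rather than by parsing a config, any export path that rebuilds the network layer by layer from a config has nothing to start from.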

References:
- https://github.com/ultralytics/yolov5/blob/master/models/yolo.py (lines 345-375)
- https://github.com/ultralytics/yolov5/blob/master/export.py (line 1431)

The TF export path in `export.py` currently rejects classification models outright:

```python
assert not isinstance(model, ClassificationModel), "ClassificationModel export to TF formats not yet supported."
```

Current `ClassificationModel` in `models/yolo.py`:

```python
# Excerpt from models/yolo.py; BaseModel is defined earlier in the same file.
from models.common import Classify, DetectMultiBackend


class ClassificationModel(BaseModel):
    """YOLOv5 classification model for image classification tasks, initialized with a config file or detection model."""

    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):
        """Initializes YOLOv5 model with config file `cfg`, detection model `model`, number of classes `nc`, and `cutoff` index."""
        super().__init__()
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        """Creates a classification model from a YOLOv5 detection model, slicing at `cutoff` and adding a classification layer."""
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, "conv") else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, "models.common.Classify"  # index, from, type
        model.model[-1] = c  # replace
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg):
        """Creates a YOLOv5 classification model from a specified *.yaml configuration file."""
        self.model = None  # not implemented: nothing is built from the YAML config
```
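One possible direction for `_from_yaml`, sketched here only at the config level and only as an assumption about what the fix could look like: describe the classification network directly in YOLOv5's `[from, number, module, args]` row format, so it would be defined by a config like detection models instead of by slicing one. The `CLS_CFG` dict (including its `Classify` head row) and the `summarize` helper are hypothetical and not part of YOLOv5; the backbone rows are copied from the `yolov5s.yaml` backbone minus the final `SPPF` (the layer `_from_detection_model` replaces), and the repeat scaling follows the `max(round(n * gd), 1) if n > 1 else n` rule used in YOLOv5's `parse_model`. No PyTorch or TF layers are built here.

```python
# Hypothetical classification config in YOLOv5's [from, number, module, args] format.
# Backbone rows follow yolov5s.yaml without SPPF; the Classify row is an assumption.
CLS_CFG = {
    "nc": 1000,
    "depth_multiple": 0.33,
    "width_multiple": 0.50,
    "backbone": [
        [-1, 1, "Conv", [64, 6, 2, 2]],
        [-1, 1, "Conv", [128, 3, 2]],
        [-1, 3, "C3", [128]],
        [-1, 1, "Conv", [256, 3, 2]],
        [-1, 6, "C3", [256]],
        [-1, 1, "Conv", [512, 3, 2]],
        [-1, 9, "C3", [512]],
        [-1, 1, "Conv", [1024, 3, 2]],
        [-1, 3, "C3", [1024]],
    ],
    "head": [
        [-1, 1, "Classify", ["nc"]],  # hypothetical head row
    ],
}


def summarize(cfg):
    """Return (index, from, repeats, module) per row, with YOLOv5-style depth scaling."""
    gd = cfg["depth_multiple"]
    out = []
    for i, (f, n, m, _args) in enumerate(cfg["backbone"] + cfg["head"]):
        n = max(round(n * gd), 1) if n > 1 else n  # depth gain, as in parse_model
        out.append((i, f, n, m))
    return out


for row in summarize(CLS_CFG):
    print(row)  # last row: (9, -1, 1, 'Classify')
```

A config like this would give export code a declarative description of every layer, which is what the detection-derived construction path lacks.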

Use case

No response

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!