Skip to content

vllm.multimodal.media

Modules:

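| Name | Description |
| --- | --- |
| `audio` | `AudioMediaIO` and `AudioEmbeddingMediaIO` (source: `vllm/multimodal/media/audio.py`). |
| `base` | `MediaIO` and `MediaWithBytes` (source: `vllm/multimodal/media/base.py`). |
| `image` | `ImageMediaIO` and `ImageEmbeddingMediaIO` (source: `vllm/multimodal/media/image.py`). |
| `video` | `VideoMediaIO` (source: `vllm/multimodal/media/video.py`). |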

__all__ module-attribute

# Public API: the abstract loader interface, the bytes-preserving wrapper,
# and the concrete loaders for audio, image and video data.
__all__ = [
    "MediaIO", "MediaWithBytes",
    "AudioEmbeddingMediaIO", "AudioMediaIO",
    "ImageEmbeddingMediaIO", "ImageMediaIO",
    "VideoMediaIO",
]

AudioEmbeddingMediaIO

Bases: MediaIO[Tensor]

Source code in vllm/multimodal/media/audio.py
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load precomputed audio embeddings serialized as torch tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a tensor from ``torch.load``-readable bytes, densified."""
        # Sparse invariant checks stop maliciously crafted sparse tensors
        # from causing out-of-bounds writes.
        with torch.sparse.check_sparse_tensor_invariants():
            loaded = torch.load(BytesIO(data), weights_only=True)
            return loaded.to_dense()

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Strictly base64-decode ``data`` and load it; ``media_type`` is unused."""
        raw = pybase64.b64decode(data, validate=True)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load the tensor stored at ``filepath``, densified."""
        # Same sparse-tensor hardening as ``load_bytes``.
        with torch.sparse.check_sparse_tensor_invariants():
            loaded = torch.load(filepath, weights_only=True)
            return loaded.to_dense()

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` to base64 text via ``tensor2base64``."""
        encoded = tensor2base64(media)
        return encoded

__init__

__init__() -> None
Source code in vllm/multimodal/media/audio.py
def __init__(self) -> None:
    """Initialize the loader; it keeps no configuration of its own."""
    super().__init__()

encode_base64

encode_base64(media: Tensor) -> str
Source code in vllm/multimodal/media/audio.py
def encode_base64(self, media: torch.Tensor) -> str:
    """Return ``media`` serialized to base64 text by ``tensor2base64``."""
    encoded = tensor2base64(media)
    return encoded

load_base64

load_base64(media_type: str, data: str) -> Tensor
Source code in vllm/multimodal/media/audio.py
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Strictly base64-decode ``data`` and load the tensor it holds.

    ``media_type`` is accepted for interface compatibility but unused.
    """
    raw = pybase64.b64decode(data, validate=True)
    return self.load_bytes(raw)

load_bytes

load_bytes(data: bytes) -> Tensor
Source code in vllm/multimodal/media/audio.py
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Deserialize a tensor from ``torch.load``-readable bytes and densify it."""
    stream = BytesIO(data)
    # Sparse invariant checks reject malformed sparse tensors that could
    # otherwise cause out-of-bounds writes.
    invariant_guard = torch.sparse.check_sparse_tensor_invariants()
    with invariant_guard:
        return torch.load(stream, weights_only=True).to_dense()

load_file

load_file(filepath: Path) -> Tensor
Source code in vllm/multimodal/media/audio.py
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load the tensor saved at ``filepath`` and return it densified."""
    # Sparse invariant checks reject malformed sparse tensors that could
    # otherwise cause out-of-bounds writes.
    invariant_guard = torch.sparse.check_sparse_tensor_invariants()
    with invariant_guard:
        return torch.load(filepath, weights_only=True).to_dense()

AudioMediaIO

Bases: MediaIO[tuple[NDArray, float]]

Source code in vllm/multimodal/media/audio.py
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load audio with ``librosa.load(..., sr=None)``; encode with ``soundfile``."""

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode in-memory audio bytes with ``librosa.load`` (``sr=None``)."""
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Strictly base64-decode ``data`` and load it; ``media_type`` is unused.

        Raises:
            binascii.Error: (a ``ValueError``) if ``data`` contains characters
                outside the base64 alphabet or is incorrectly padded.
        """
        # `validate=True` rejects invalid characters instead of silently
        # discarding them, matching the other base64 loaders.
        return self.load_bytes(base64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Decode the audio file at ``filepath`` with ``librosa.load`` (``sr=None``)."""
        return librosa.load(filepath, sr=None)

    def encode_base64(
        self,
        media: tuple[npt.NDArray, int],
        *,
        audio_format: str = "WAV",
    ) -> str:
        """Write ``(audio, sr)`` with ``soundfile`` in ``audio_format``, then base64 it."""
        audio, sr = media

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sr, format=audio_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode("utf-8")

kwargs instance-attribute

kwargs = kwargs

__init__

__init__(**kwargs) -> None
Source code in vllm/multimodal/media/audio.py
def __init__(self, **kwargs) -> None:
    """Keep modality-specific options passed via ``--media-io-kwargs``.

    The options are stored on ``self.kwargs`` so custom media loaders can
    read them; the built-in load methods of this class do not read them.
    """
    super().__init__()
    self.kwargs = kwargs

encode_base64

encode_base64(
    media: tuple[NDArray, int], *, audio_format: str = "WAV"
) -> str
Source code in vllm/multimodal/media/audio.py
def encode_base64(
    self,
    media: tuple[npt.NDArray, int],
    *,
    audio_format: str = "WAV",
) -> str:
    """Serialize ``(audio, sample_rate)`` into base64 text.

    The waveform is written with ``soundfile`` in ``audio_format``
    (``"WAV"`` by default), and the written bytes are base64-encoded.
    """
    waveform, sample_rate = media
    with BytesIO() as sink:
        soundfile.write(sink, waveform, sample_rate, format=audio_format)
        encoded_file = sink.getvalue()
    return base64.b64encode(encoded_file).decode("utf-8")

load_base64

load_base64(
    media_type: str, data: str
) -> tuple[NDArray, float]
Source code in vllm/multimodal/media/audio.py
def load_base64(
    self,
    media_type: str,
    data: str,
) -> tuple[npt.NDArray, float]:
    """Base64-decode ``data`` and load the audio it contains.

    ``media_type`` is unused.

    Raises:
        binascii.Error: (a ``ValueError``) if ``data`` contains characters
            outside the base64 alphabet or is incorrectly padded.
    """
    # Strict decoding rejects invalid characters instead of silently dropping
    # them, consistent with the other base64 loaders.
    return self.load_bytes(base64.b64decode(data, validate=True))

load_bytes

load_bytes(data: bytes) -> tuple[NDArray, float]
Source code in vllm/multimodal/media/audio.py
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
    """Decode in-memory audio bytes via ``librosa.load`` with ``sr=None``."""
    stream = BytesIO(data)
    return librosa.load(stream, sr=None)

load_file

load_file(filepath: Path) -> tuple[NDArray, float]
Source code in vllm/multimodal/media/audio.py
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
    """Decode the audio file at ``filepath`` via ``librosa.load`` with ``sr=None``."""
    decoded = librosa.load(filepath, sr=None)
    return decoded

ImageEmbeddingMediaIO

Bases: MediaIO[Tensor]

Source code in vllm/multimodal/media/image.py
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and serialize precomputed image embeddings stored as torch tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a tensor from ``torch.load``-readable bytes, densified."""
        buffer = BytesIO(data)
        # Enable sparse tensor integrity checks to prevent out-of-bounds
        # writes from maliciously crafted tensors
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(buffer, weights_only=True)
            return tensor.to_dense()

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Strictly base64-decode ``data`` and load it; ``media_type`` is unused."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load the tensor stored at ``filepath``, densified."""
        # Enable sparse tensor integrity checks to prevent out-of-bounds
        # writes from maliciously crafted tensors
        with torch.sparse.check_sparse_tensor_invariants():
            tensor = torch.load(filepath, weights_only=True)
            return tensor.to_dense()

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` to base64 in the format ``load_base64`` reads.

        The tensor is written with ``torch.save``, so the result round-trips
        through ``load_base64`` (which reads it back with ``torch.load``),
        preserving dtype and shape.
        """
        with BytesIO() as buffer:
            torch.save(media, buffer)
            data = buffer.getvalue()
        return pybase64.b64encode(data).decode("utf-8")

__init__

__init__() -> None
Source code in vllm/multimodal/media/image.py
def __init__(self) -> None:
    """Initialize the loader; it keeps no configuration of its own."""
    super().__init__()

encode_base64

encode_base64(media: Tensor) -> str
Source code in vllm/multimodal/media/image.py
def encode_base64(self, media: torch.Tensor) -> str:
    """Serialize ``media`` to base64 in the format ``load_base64`` reads.

    The tensor is written with ``torch.save``, so the result round-trips
    through ``load_base64`` (which reads it back with ``torch.load``),
    preserving dtype and shape.
    """
    with BytesIO() as buffer:
        torch.save(media, buffer)
        data = buffer.getvalue()
    return pybase64.b64encode(data).decode("utf-8")

load_base64

load_base64(media_type: str, data: str) -> Tensor
Source code in vllm/multimodal/media/image.py
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
    """Strictly base64-decode ``data`` and load the tensor it holds.

    ``media_type`` is accepted for interface compatibility but unused.
    """
    raw = pybase64.b64decode(data, validate=True)
    return self.load_bytes(raw)

load_bytes

load_bytes(data: bytes) -> Tensor
Source code in vllm/multimodal/media/image.py
def load_bytes(self, data: bytes) -> torch.Tensor:
    """Deserialize a tensor from ``torch.load``-readable bytes and densify it."""
    stream = BytesIO(data)
    # Sparse invariant checks reject malformed sparse tensors that could
    # otherwise cause out-of-bounds writes.
    invariant_guard = torch.sparse.check_sparse_tensor_invariants()
    with invariant_guard:
        return torch.load(stream, weights_only=True).to_dense()

load_file

load_file(filepath: Path) -> Tensor
Source code in vllm/multimodal/media/image.py
def load_file(self, filepath: Path) -> torch.Tensor:
    """Load the tensor saved at ``filepath`` and return it densified."""
    # Sparse invariant checks reject malformed sparse tensors that could
    # otherwise cause out-of-bounds writes.
    invariant_guard = torch.sparse.check_sparse_tensor_invariants()
    with invariant_guard:
        return torch.load(filepath, weights_only=True).to_dense()

ImageMediaIO

Bases: MediaIO[Image]

Source code in vllm/multimodal/media/image.py
class ImageMediaIO(MediaIO[Image.Image]):
    """Load images with PIL, converting them to ``image_mode``.

    Loaded images are returned as ``MediaWithBytes`` so the converted image
    stays paired with the original encoded bytes.
    """

    def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
        """
        Args:
            image_mode: PIL mode every loaded/encoded image is converted to.
            **kwargs: options from ``--media-io-kwargs``; the optional
                ``rgba_background_color`` (3 ints in [0, 255]) is used when
                converting RGBA images to RGB.

        Raises:
            ValueError: if ``rgba_background_color`` is malformed.
        """
        super().__init__()

        self.image_mode = image_mode
        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.
        self.kwargs = kwargs

        # Extract RGBA background color from kwargs if provided
        # Default to white background for backward compatibility
        rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255))
        # Convert list to tuple for consistency
        if isinstance(rgba_bg, list):
            rgba_bg = tuple(rgba_bg)

        # Validate rgba_background_color format
        if not (
            isinstance(rgba_bg, tuple)
            and len(rgba_bg) == 3
            and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg)
        ):
            raise ValueError(
                "rgba_background_color must be a list or tuple of 3 integers "
                "in the range [0, 255]."
            )
        self.rgba_background_color = rgba_bg

    def _convert_image_mode(
        self, image: Image.Image | MediaWithBytes[Image.Image]
    ) -> Image.Image:
        """Convert image mode with custom background color."""
        if isinstance(image, MediaWithBytes):
            image = image.media
        if image.mode == self.image_mode:
            return image
        elif image.mode == "RGBA" and self.image_mode == "RGB":
            return rgba_to_rgb(image, self.rgba_background_color)
        else:
            return convert_image_mode(image, self.image_mode)

    def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
        """Decode ``data``, convert its mode, and keep the raw bytes alongside."""
        image = Image.open(BytesIO(data))
        return MediaWithBytes(self._convert_image_mode(image), data)

    def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
        """Strictly base64-decode ``data`` and load it; ``media_type`` is unused."""
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
        """Read ``filepath`` fully, then decode it as in ``load_bytes``."""
        with open(filepath, "rb") as f:
            data = f.read()
        image = Image.open(BytesIO(data))
        return MediaWithBytes(self._convert_image_mode(image), data)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str | None = None,
    ) -> str:
        """Convert ``media`` to ``image_mode``, save it as ``image_format``, base64 it.

        If ``image_format`` is omitted, a deprecation warning is logged via
        ``logger.warning_once`` and ``"JPEG"`` is used.
        """
        if image_format is None:
            # The hint must name the real keyword: `format=` is not a
            # parameter of this method and would raise a TypeError.
            logger.warning_once(
                "The default format of `ImageMediaIO.encode_base64` will be changed "
                'from "JPEG" to "PNG" in v0.15 to avoid lossy compression. '
                "To continue using the old default, "
                'pass `image_format="JPEG"` explicitly to silence this warning.'
            )
            image_format = "JPEG"

        image = media

        with BytesIO() as buffer:
            image = self._convert_image_mode(image)
            image.save(buffer, image_format)
            data = buffer.getvalue()

        return pybase64.b64encode(data).decode("utf-8")

image_mode instance-attribute

image_mode = image_mode

kwargs instance-attribute

kwargs = kwargs

rgba_background_color instance-attribute

rgba_background_color = rgba_bg

__init__

__init__(image_mode: str = 'RGB', **kwargs) -> None
Source code in vllm/multimodal/media/image.py
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
    """Store the target image mode and validate the RGBA background color.

    ``kwargs`` holds options from ``--media-io-kwargs``. Its optional
    ``rgba_background_color`` entry defaults to white ``(255, 255, 255)``;
    a list is accepted and normalized to a tuple.

    Raises:
        ValueError: if the color is not 3 ints in ``[0, 255]``.
    """
    super().__init__()

    self.image_mode = image_mode
    # Kept so custom media loaders can read extra per-modality options.
    self.kwargs = kwargs

    background = kwargs.get("rgba_background_color", (255, 255, 255))
    if isinstance(background, list):
        background = tuple(background)

    is_valid_color = (
        isinstance(background, tuple)
        and len(background) == 3
        and all(isinstance(c, int) and 0 <= c <= 255 for c in background)
    )
    if not is_valid_color:
        raise ValueError(
            "rgba_background_color must be a list or tuple of 3 integers "
            "in the range [0, 255]."
        )
    self.rgba_background_color = background

_convert_image_mode

_convert_image_mode(
    image: Image | MediaWithBytes[Image],
) -> Image

Convert image mode with custom background color.

Source code in vllm/multimodal/media/image.py
def _convert_image_mode(
    self, image: Image.Image | MediaWithBytes[Image.Image]
) -> Image.Image:
    """Return ``image`` in ``self.image_mode``, unwrapping ``MediaWithBytes``.

    RGBA -> RGB goes through ``rgba_to_rgb`` with
    ``self.rgba_background_color``; any other mismatch goes through
    ``convert_image_mode``. Images already in the target mode are
    returned unchanged.
    """
    pil_image = image.media if isinstance(image, MediaWithBytes) else image
    source_mode, target_mode = pil_image.mode, self.image_mode
    if source_mode == target_mode:
        return pil_image
    if (source_mode, target_mode) == ("RGBA", "RGB"):
        return rgba_to_rgb(pil_image, self.rgba_background_color)
    return convert_image_mode(pil_image, target_mode)

encode_base64

encode_base64(
    media: Image, *, image_format: str | None = None
) -> str
Source code in vllm/multimodal/media/image.py
def encode_base64(
    self,
    media: Image.Image,
    *,
    image_format: str | None = None,
) -> str:
    """Convert ``media`` to ``self.image_mode``, save as ``image_format``, base64 it.

    If ``image_format`` is omitted, a deprecation warning is logged via
    ``logger.warning_once`` and ``"JPEG"`` is used.
    """
    if image_format is None:
        # The hint must name the real keyword: `format=` is not a parameter
        # of this method and would raise a TypeError.
        logger.warning_once(
            "The default format of `ImageMediaIO.encode_base64` will be changed "
            'from "JPEG" to "PNG" in v0.15 to avoid lossy compression. '
            "To continue using the old default, "
            'pass `image_format="JPEG"` explicitly to silence this warning.'
        )
        image_format = "JPEG"

    image = media

    with BytesIO() as buffer:
        image = self._convert_image_mode(image)
        image.save(buffer, image_format)
        data = buffer.getvalue()

    return pybase64.b64encode(data).decode("utf-8")

load_base64

load_base64(
    media_type: str, data: str
) -> MediaWithBytes[Image]
Source code in vllm/multimodal/media/image.py
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
    """Strictly base64-decode ``data`` and load it via ``load_bytes``.

    ``media_type`` is not used.
    """
    image_bytes = pybase64.b64decode(data, validate=True)
    return self.load_bytes(image_bytes)

load_bytes

load_bytes(data: bytes) -> MediaWithBytes[Image]
Source code in vllm/multimodal/media/image.py
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
    """Decode ``data``, convert its mode, and pair it with the raw bytes."""
    converted = self._convert_image_mode(Image.open(BytesIO(data)))
    return MediaWithBytes(media=converted, original_bytes=data)

load_file

load_file(filepath: Path) -> MediaWithBytes[Image]
Source code in vllm/multimodal/media/image.py
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
    """Read ``filepath`` fully, convert the decoded image, keep the raw bytes."""
    with open(filepath, "rb") as handle:
        encoded = handle.read()
    decoded = self._convert_image_mode(Image.open(BytesIO(encoded)))
    return MediaWithBytes(decoded, encoded)

MediaIO

Bases: ABC, Generic[_T]

Source code in vllm/multimodal/media/base.py
class MediaIO(ABC, Generic[_T]):
    """Abstract interface for loading one kind of media into a ``_T`` value.

    Concrete subclasses must implement all three loaders below.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Load media from its raw encoded ``data``."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """Load media from base64 text ``data`` of type ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load media from the file at ``filepath``."""
        raise NotImplementedError

load_base64 abstractmethod

load_base64(media_type: str, data: str) -> _T

List of media types: https://www.iana.org/assignments/media-types/media-types.xhtml

Source code in vllm/multimodal/media/base.py
@abstractmethod
def load_base64(self, media_type: str, data: str) -> _T:
    """Load media from base64 text ``data`` of type ``media_type``.

    List of media types:
    https://www.iana.org/assignments/media-types/media-types.xhtml
    """
    raise NotImplementedError

load_bytes abstractmethod

load_bytes(data: bytes) -> _T
Source code in vllm/multimodal/media/base.py
@abstractmethod
def load_bytes(self, data: bytes) -> _T:
    """Load media from its raw encoded ``data``; subclasses must override."""
    raise NotImplementedError

load_file abstractmethod

load_file(filepath: Path) -> _T
Source code in vllm/multimodal/media/base.py
@abstractmethod
def load_file(self, filepath: Path) -> _T:
    """Load media from the file at ``filepath``; subclasses must override."""
    raise NotImplementedError

MediaWithBytes dataclass

Bases: Generic[_T]

Wrapper that couples a media object with its original encoded bytes.

This ensures the raw bytes and media object remain synchronized, preventing cache corruption from in-place modifications.

The wrapper delegates attribute access to the underlying media object, making it behave transparently like the wrapped type (e.g., PIL.Image).

NOTE: Currently, this wrapper is used only for the image modality.

Source code in vllm/multimodal/media/base.py
@dataclass
class MediaWithBytes(Generic[_T]):
    """
    Wrapper that couples a media object with its original encoded bytes.

    This ensures the raw bytes and media object remain synchronized,
    preventing cache corruption from in-place modifications.

    The wrapper delegates attribute access to the underlying media object,
    making it behave transparently like the wrapped type (e.g., PIL.Image).

    NOTE: Currently, this wrapper is used only for the image modality.
    """

    media: _T
    original_bytes: bytes

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """Allow np.array(obj) to return np.array(obj.media)."""
        return np.array(self.media, *args, **kwargs)

    def __getstate__(self):
        # Pickle state is a shallow copy of this wrapper's own fields.
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]):
        # Restore fields directly into the instance dict.
        self.__dict__.update(state)

    def __getattr__(self, name: str):
        """Delegate attribute access to the underlying media object.

        Only called when normal lookup fails. If ``media`` itself is missing
        (e.g. an instance created via ``__new__`` whose state has not been
        restored yet), an ``AttributeError`` is raised instead of recursing
        forever, so ``hasattr``/``getattr(..., default)`` behave normally.
        """
        if name == "media":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.media, name)

media instance-attribute

media: _T

original_bytes instance-attribute

original_bytes: bytes

__array__

__array__(*args, **kwargs) -> ndarray

Allow np.array(obj) to return np.array(obj.media).

Source code in vllm/multimodal/media/base.py
def __array__(self, *args, **kwargs) -> np.ndarray:
    """Make ``np.array(wrapper)`` equivalent to ``np.array(wrapper.media)``.

    All positional and keyword arguments are forwarded to ``np.array``.
    """
    wrapped = self.media
    return np.array(wrapped, *args, **kwargs)

__getattr__

__getattr__(name: str)

Delegate attribute access to the underlying media object.

Source code in vllm/multimodal/media/base.py
def __getattr__(self, name: str):
    """Delegate attribute access to the underlying media object.

    Only called when normal lookup fails. If ``media`` itself is missing
    (e.g. an instance created via ``__new__`` whose state has not been
    restored yet), an ``AttributeError`` is raised instead of recursing
    forever, so ``hasattr``/``getattr(..., default)`` behave normally.
    """
    if name == "media":
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )
    return getattr(self.media, name)

__getstate__

__getstate__()
Source code in vllm/multimodal/media/base.py
def __getstate__(self):
    """Return a shallow copy of the instance dict as the pickle state."""
    return dict(self.__dict__)

__init__

__init__(media: _T, original_bytes: bytes) -> None

__setstate__

__setstate__(state: dict[str, Any])
Source code in vllm/multimodal/media/base.py
def __setstate__(self, state: dict[str, Any]):
    """Merge the unpickled ``state`` mapping into the instance dict."""
    instance_dict = self.__dict__
    instance_dict.update(state)

VideoMediaIO

Bases: MediaIO[tuple[NDArray, dict[str, Any]]]

Source code in vllm/multimodal/media/video.py
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
    """Load videos as ``(frames, metadata)`` using a registered video loader.

    Container decoding is delegated to the backend obtained from
    ``VIDEO_LOADER_REGISTRY``. ``video/jpeg`` payloads (comma-separated
    base64 frames) are decoded frame by frame with ``image_io``.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        num_frames: int = 32,
        **kwargs,
    ) -> None:
        super().__init__()

        self.image_io = image_io
        self.num_frames = num_frames
        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control.

        # Allow per-request override of video backend via kwargs.
        # This enables users to specify a different backend than the
        # global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
        #   --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
        video_loader_backend = (
            kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
        )
        self.kwargs = kwargs
        self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
        """Decode ``data`` with the video loader, forwarding ``num_frames`` and kwargs."""
        return self.video_loader.load_bytes(
            data, num_frames=self.num_frames, **self.kwargs
        )

    def load_base64(
        self, media_type: str, data: str
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Load base64 video data.

        For ``video/jpeg`` (case-insensitive), ``data`` is a comma-separated
        list of base64 JPEG frames; they are stacked and returned with empty
        metadata. Any other type is strictly base64-decoded and passed to
        ``load_bytes``.

        Raises:
            binascii.Error: (a ``ValueError``) on malformed base64 input.
        """
        if media_type.lower() == "video/jpeg":
            load_frame = partial(
                self.image_io.load_base64,
                "image/jpeg",
            )

            return np.stack(
                [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
            ), {}

        # Strict decoding, consistent with the image loader used for the
        # JPEG-frame branch above.
        return self.load_bytes(base64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
        """Read ``filepath`` fully and decode it via ``load_bytes``."""
        with filepath.open("rb") as f:
            data = f.read()

        return self.load_bytes(data)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode each frame as base64 JPEG and join them with commas.

        Raises:
            NotImplementedError: for any ``video_format`` other than ``"JPEG"``.
        """
        video = media

        if video_format == "JPEG":
            encode_frame = partial(
                self.image_io.encode_base64,
                image_format=video_format,
            )

            return ",".join(encode_frame(Image.fromarray(frame)) for frame in video)

        msg = "Only JPEG format is supported for now."
        raise NotImplementedError(msg)

image_io instance-attribute

image_io = image_io

kwargs instance-attribute

kwargs = kwargs

num_frames instance-attribute

num_frames = num_frames

video_loader instance-attribute

video_loader = load(video_loader_backend)

__init__

__init__(
    image_io: ImageMediaIO, num_frames: int = 32, **kwargs
) -> None
Source code in vllm/multimodal/media/video.py
def __init__(
    self,
    image_io: ImageMediaIO,
    num_frames: int = 32,
    **kwargs,
) -> None:
    """Configure frame sampling and pick the video loader backend.

    Args:
        image_io: loader used for individual JPEG frames.
        num_frames: forwarded to the video loader on every load.
        **kwargs: options from ``--media-io-kwargs``. An optional
            ``"video_backend"`` entry overrides ``VLLM_VIDEO_LOADER_BACKEND``
            (e.g. ``{"video": {"video_backend": "torchcodec"}}``).
    """
    super().__init__()

    self.image_io = image_io
    self.num_frames = num_frames

    # Pop the backend choice so it is not stored with the remaining options
    # in ``self.kwargs``; fall back to the global env setting when absent.
    requested_backend = kwargs.pop("video_backend", None)
    backend_name = requested_backend or envs.VLLM_VIDEO_LOADER_BACKEND
    self.kwargs = kwargs
    self.video_loader = VIDEO_LOADER_REGISTRY.load(backend_name)

encode_base64

encode_base64(
    media: NDArray, *, video_format: str = "JPEG"
) -> str
Source code in vllm/multimodal/media/video.py
def encode_base64(
    self,
    media: npt.NDArray,
    *,
    video_format: str = "JPEG",
) -> str:
    """Encode every frame of ``media`` as base64 and join them with commas.

    Frames are turned into PIL images and encoded with
    ``self.image_io.encode_base64``.

    Raises:
        NotImplementedError: for any ``video_format`` other than ``"JPEG"``.
    """
    if video_format != "JPEG":
        raise NotImplementedError("Only JPEG format is supported for now.")

    encode_frame = self.image_io.encode_base64
    return ",".join(
        encode_frame(Image.fromarray(frame), image_format=video_format)
        for frame in media
    )

load_base64

load_base64(
    media_type: str, data: str
) -> tuple[NDArray, dict[str, Any]]
Source code in vllm/multimodal/media/video.py
def load_base64(
    self, media_type: str, data: str
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Load base64 video data.

    For ``video/jpeg`` (case-insensitive), ``data`` is a comma-separated list
    of base64 JPEG frames; they are stacked and returned with empty metadata.
    Any other type is strictly base64-decoded and passed to ``load_bytes``.

    Raises:
        binascii.Error: (a ``ValueError``) on malformed base64 input.
    """
    if media_type.lower() == "video/jpeg":
        load_frame = partial(
            self.image_io.load_base64,
            "image/jpeg",
        )

        return np.stack(
            [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
        ), {}

    # Strict decoding, consistent with the image loader used for the
    # JPEG-frame branch above.
    return self.load_bytes(base64.b64decode(data, validate=True))

load_bytes

load_bytes(data: bytes) -> tuple[NDArray, dict[str, Any]]
Source code in vllm/multimodal/media/video.py
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
    """Decode ``data`` with the configured video loader.

    ``self.num_frames`` and every entry of ``self.kwargs`` are forwarded as
    keyword arguments.
    """
    loader = self.video_loader
    return loader.load_bytes(data, num_frames=self.num_frames, **self.kwargs)

load_file

load_file(filepath: Path) -> tuple[NDArray, dict[str, Any]]
Source code in vllm/multimodal/media/video.py
def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
    """Read the whole file at ``filepath`` and decode it via ``load_bytes``."""
    with filepath.open("rb") as handle:
        payload = handle.read()
    return self.load_bytes(payload)