@@ -1398,3 +1398,209 @@ def verify_image(self, image: Image) -> WatermarkVerificationResponse:
13981398 _prediction_response = response ,
13991399 watermark_verification_result = verification_likelihood ,
14001400 )
1401+
1402+
class Scribble:
    """Input scribble for image segmentation.

    Wraps a single `Image` built from either raw image bytes or a Google
    Cloud Storage URI.
    """

    __module__ = "vertexai.preview.vision_models"

    _image_: Optional[Image] = None

    def __init__(
        self,
        image_bytes: Optional[bytes] = None,
        gcs_uri: Optional[str] = None,
    ):
        """Creates a `Scribble` object.

        Exactly one of `image_bytes` and `gcs_uri` must be non-empty.

        Args:
            image_bytes: Scribble image file bytes.
            gcs_uri: Scribble image file Google Cloud Storage uri.

        Raises:
            ValueError: If both arguments are empty, or if both are set.
        """
        # Equal truthiness means the caller supplied neither source or both;
        # either way there is no single unambiguous image to wrap.
        if bool(image_bytes) == bool(gcs_uri):
            raise ValueError("Either image_bytes or gcs_uri must be provided.")

        self._image_ = Image(image_bytes, gcs_uri)

    @property
    def image(self) -> Optional[Image]:
        """The scribble image."""
        return self._image_
1430+
1431+
@dataclasses.dataclass
class EntityLabel:
    """A text label for an entity, optionally paired with a confidence score.

    Attributes:
        label: The text label; None when not set.
        score: The confidence score associated with the label; None when
            not set.
    """

    __module__ = "vertexai.preview.vision_models"

    label: Optional[str] = None
    score: Optional[float] = None
1440+
1441+
class GeneratedMask(Image):
    """Generated image mask.

    An `Image` that additionally carries the entity labels attached to the
    masked object.
    """

    __module__ = "vertexai.preview.vision_models"

    __labels__: Optional[List[EntityLabel]] = None

    def __init__(
        self,
        image_bytes: Optional[bytes],
        gcs_uri: Optional[str] = None,
        labels: Optional[List[EntityLabel]] = None,
    ):
        """Creates a `GeneratedMask` object.

        Args:
            image_bytes: Mask image file bytes.
            gcs_uri: Mask image file Google Cloud Storage uri.
            labels: Generated entity labels. Each text label might be
                associated with a confidence score.
        """
        # The image payload is handled entirely by the `Image` base class;
        # this subclass only stores the labels alongside it.
        super().__init__(image_bytes=image_bytes, gcs_uri=gcs_uri)
        self.__labels__ = labels

    @property
    def labels(self) -> Optional[List[EntityLabel]]:
        """The entity labels of the masked object."""
        return self.__labels__
1474+
1475+
@dataclasses.dataclass
class ImageSegmentationResponse:
    """Image Segmentation response.

    Supports iteration and indexing, both of which delegate to `masks`.

    Attributes:
        masks: The list of generated masks.
    """

    __module__ = "vertexai.preview.vision_models"

    # Raw prediction response object this result was built from.
    _prediction_response: Any
    masks: List[GeneratedMask]

    def __iter__(self) -> typing.Iterator[GeneratedMask]:
        """Iterates through the generated masks."""
        for mask in self.masks:
            yield mask

    def __getitem__(self, idx: int) -> GeneratedMask:
        """Gets the generated masks by index."""
        return self.masks[idx]
1496+
1497+
class ImageSegmentationModel(_model_garden_models._ModelGardenModel):
    """Segments an image."""

    __module__ = "vertexai.preview.vision_models"

    _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/image_segmentation_model_1.0.0.yaml"

    @staticmethod
    def _image_payload(image: Image) -> typing.Dict[str, str]:
        """Builds the request payload for one image.

        A Cloud Storage URI is preferred when the image has one; otherwise the
        image content is sent inline as a base64 string.
        """
        if image._gcs_uri:
            return {"gcsUri": image._gcs_uri}
        return {"bytesBase64Encoded": image._as_base64_string()}

    def segment_image(
        self,
        base_image: Image,
        prompt: Optional[str] = None,
        scribble: Optional[Scribble] = None,
        mode: Literal[
            "foreground", "background", "semantic", "prompt", "interactive"
        ] = "foreground",
        max_predictions: Optional[int] = None,
        confidence_threshold: Optional[float] = 0.1,
        mask_dilation: Optional[float] = None,
    ) -> ImageSegmentationResponse:
        """Segments an image.

        Args:
            base_image: The base image to segment.
            prompt: The prompt to guide the segmentation. Valid for the prompt
                and semantic modes.
            scribble: The scribble in the form of an image mask to guide the
                segmentation. Valid for the interactive mode. The scribble
                image should be a black-and-white PNG file equal in size to
                the base image. White pixels represent the scribbled brush
                stroke which select objects in the base image to segment.
            mode: The segmentation mode. The default is foreground. Supported
                values are:
                * foreground: segment the foreground object of an image
                * background: segment the background of an image
                * semantic: specify the objects to segment with a comma
                  delimited list of objects from the class set in the prompt.
                * prompt: use an open-vocabulary text prompt to select objects
                  to segment.
                * interactive: draw scribbles with a brush stroke to guide the
                  segmentation.
            max_predictions: The maximum number of predictions to make. Valid
                for the prompt mode. Default is unlimited. The parameter is
                omitted from the request when it is None or 0.
            confidence_threshold: A threshold to filter predictions by
                confidence score. The value must be in the range of 0.0 and
                1.0. The default is 0.1. Pass None to omit the parameter from
                the request.
            mask_dilation: A value to dilate the masks by. The value must be in
                the range of 0.0 (no dilation) and 1.0 (the whole image will be
                masked). The default is 0.0. The parameter is omitted from the
                request when it is None.

        Returns:
            An `ImageSegmentationResponse` object with the generated masks,
            entities, and labels (if any).

        Raises:
            ValueError: If `base_image` is missing.
        """
        if not base_image:
            raise ValueError("Base image is required.")

        instance = {"image": self._image_payload(base_image)}
        if prompt:
            instance["prompt"] = prompt
        if scribble and scribble.image:
            instance["scribble"] = {"image": self._image_payload(scribble.image)}

        parameters = {"mode": mode}
        if max_predictions:
            parameters["maxPredictions"] = max_predictions
        # Compare against None rather than relying on truthiness: 0.0 is a
        # valid value for both settings and must not be silently dropped.
        if confidence_threshold is not None:
            parameters["confidenceThreshold"] = confidence_threshold
        if mask_dilation is not None:
            parameters["maskDilation"] = mask_dilation

        response = self._endpoint.predict(
            instances=[instance],
            parameters=parameters,
        )

        masks: List[GeneratedMask] = []
        for prediction in response.predictions:
            encoded_bytes = prediction.get("bytesBase64Encoded")
            # A missing or null "labels" field yields an empty label list.
            labels = [
                EntityLabel(label=entry.get("label"), score=entry.get("score"))
                for entry in (prediction.get("labels") or [])
            ]
            masks.append(
                GeneratedMask(
                    image_bytes=(
                        base64.b64decode(encoded_bytes) if encoded_bytes else None
                    ),
                    gcs_uri=prediction.get("gcsUri"),
                    labels=labels,
                )
            )

        return ImageSegmentationResponse(
            _prediction_response=response,
            masks=masks,
        )
0 commit comments