Mouse-interactive prompt engineering
A new method based on Segment Anything (SAM)
Start
Continuing the discussion on prompt engineering from last time, let’s talk about a fast way to implement object segmentation in computer vision. This task was quite difficult before SAM emerged; it even required training separate segmentation models for different targets. I won’t go into detail here.
SAM
The SAM mentioned here is the new image segmentation method released by Meta (Facebook) in early 2023; there is also a variant, HQ-SAM. Please refer to the following for more information:
Notice
This article covers the “coordinate” (point) prompt; the “bounding-box” and “text” prompts will be introduced in the future when there is time. In truth, this article is just a story about the SAM API. Readers already familiar with SAM can treat it as a review and perhaps pick up a thing or two.
SAM API
predict_torch
The model provides many APIs, and any API whose name ends with “_torch” can process the input prompts in batches.
The relevant part of the function’s signature is:

```python
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
```
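As a minimal sketch of how these two parameters are used (the checkpoint path, the placeholder image, and the click coordinates are assumptions of mine, not an official example):

```python
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

# Assumption: a locally downloaded ViT-H checkpoint; any SAM variant works.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

# Placeholder image; substitute your own HxWx3 uint8 RGB array.
image = np.zeros((600, 800, 3), dtype=np.uint8)
predictor.set_image(image)

# One coordinate group (B = 1) holding a single foreground click (N = 1).
point_coords = torch.tensor([[[400.0, 300.0]]], device=predictor.device)  # B x N x 2
point_labels = torch.tensor([[1]], device=predictor.device)               # B x N

# predict_torch expects coordinates in the model's input frame, so
# transform them from the original image frame first.
point_coords = predictor.transform.apply_coords_torch(
    point_coords, predictor.original_size
)

masks, scores, low_res_masks = predictor.predict_torch(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,  # return C (usually 3) candidate masks per group
)
```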
Official explanation
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
My explanation
“point_coords” provides a list of coordinate groups, where each coordinate group is itself a list of (x, y) points (i.e., a list of lists).
“point_labels” provides a list of label groups, where each label group is itself a list with one label per point.
In other words, as mentioned earlier, “in batches” means that point_coords and point_labels can contain more than one group, and each group produces its own mask.
1 indicates a foreground point and 0 indicates a
background point.
My explanation: a label of 1 marks a click that should fall inside the target (foreground, “retained”), while a label of 0 marks a click that should fall outside it (background, “excluded”). The mask the API generates is likewise a matrix of 1s and 0s, where the 1s mark pixels to retain and the 0s mark pixels to exclude.
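As a tiny illustration of the retained/excluded semantics (a sketch assuming “image” is an HxWx3 NumPy array and “mask” is one HxW 0/1 mask taken from the API’s output):

```python
import numpy as np

# Keep the pixels where the mask is 1 (foreground); zero out everything else.
segment = image * mask[..., None].astype(image.dtype)
```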
Pseudocode
```
point_coords: [[ [x1, y1], [x2, y2], [x3, y3], [x4, y4], ... ], ... ]
point_labels: [[ 1, 0, 0, 0, ... ], ... ]
```
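Concretely, that pseudocode maps to tensors like these (two groups of four clicks each; the coordinates are made-up placeholders):

```python
import torch

# Two coordinate groups (B = 2), four points per group (N = 4).
# In each group the first click is foreground (1) and the rest are background (0).
point_coords = torch.tensor([
    [[120.0,  80.0], [300.0, 200.0], [310.0, 210.0], [320.0, 220.0]],
    [[450.0, 330.0], [ 60.0,  40.0], [ 70.0,  50.0], [ 80.0,  60.0]],
])                                  # shape (2, 4, 2) == B x N x 2
point_labels = torch.tensor([
    [1, 0, 0, 0],
    [1, 0, 0, 0],
])                                  # shape (2, 4) == B x N
```

One practical constraint: every group must contain the same number of points for the batch to form a rectangular tensor. If your groups have different click counts, pad the shorter ones (SAM’s prompt encoder treats label -1 as a padding point, although the predict_torch docstring only documents 1 and 0).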
From a programming perspective, we can pass both point_coords and point_labels to the API and receive a batch of masks with shape (batch, C, H, W). Each element in the batch corresponds to the masks for one “coordinate group” (typically, C is 3, representing 3 candidate masks). Usually, we choose the “most reliable” mask for downstream use.
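Continuing the sketch above, the returned shapes line up with this description:

```python
# With B coordinate groups and multimask_output=True:
print(masks.shape)   # torch.Size([B, 3, H, W]) -- 3 candidate masks per group
print(scores.shape)  # torch.Size([B, 3])       -- one quality score per candidate
```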
How to select the “most reliable” mask?
The API will provide a mask score batch:
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
In my own words, each “unit” of this score batch holds the scores of the masks generated by ONE “coordinate group”. Since a group typically produces three masks, each unit holds three scores. Judging from the shape in the signature, this checks out: it is the same “C”.
Again, each “coordinate group” generates several masks (the colored rectangles below); the unit of the score batch we see for each group is a list of the quality of each of those masks.
Use argmax over the scores to determine the best mask for each batch unit.
After this processing, we obtain a mask batch with shape (batch, 1, H, W).
To repeat: before score-based selection, the API returns a batch with shape (batch, C, H, W). After determining the best mask for each “element” of the batch using the scores (shape: (batch, C)), we obtain the best-mask batch with shape (batch, 1, H, W).
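A minimal sketch of this selection step (the variable names are mine; it assumes “masks” of shape (batch, C, H, W) and “scores” of shape (batch, C) as returned by predict_torch):

```python
import torch

# Index of the highest-scoring candidate within each batch unit: shape (batch,)
best = torch.argmax(scores, dim=1)

# Select that candidate from the C dimension of every unit, then restore
# the selected dimension so the result has shape (batch, 1, H, W).
batch_idx = torch.arange(masks.shape[0], device=masks.device)
best_masks = masks[batch_idx, best].unsqueeze(1)
```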
Note that the mask content, the matrices of “0”/“1” (background/foreground) values, occupies one dimension of its own. This is because, as mentioned earlier, the API returns a mask batch in which each “unit” provides C candidate masks, with C usually being 3. When we apply argmax to each “unit” of the score batch, we are simply selecting along that dimension while preserving it. It is up to the user to squeeze this dimension out or retain it as needed.
In addition, a clumsy alternative is to use a for-loop to iterate through the groups one at a time. In this case, as mentioned earlier, the batch size of each API call and of the returned batch is 1, and each selected mask has a shape of (1, 1, H, W).
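The loop version looks roughly like this (again a sketch with names of my own; “all_coords” and “all_labels” are assumed lists of per-group tensors, and each call sends a single group, so the batch size is 1):

```python
import torch

best_masks = []
for coords, labels in zip(all_coords, all_labels):
    # coords: (1, N, 2) and labels: (1, N) -- one coordinate group per call.
    masks, scores, _ = predictor.predict_torch(
        point_coords=coords,
        point_labels=labels,
        multimask_output=True,
    )                                   # masks: (1, 3, H, W), scores: (1, 3)
    best = torch.argmax(scores, dim=1)  # index of the best of the 3 candidates
    best_masks.append(masks[:, best])   # (1, 1, H, W)
```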
Full image: https://dl.dropbox.com/s/81suav264sfrx0o/SAM-return-best-filter.jpg