How to correctly extract the CLS token from a Keras Hub ViT backbone, and clarify preprocessor usage and pretraining dataset?


I’m working with a Vision Transformer (ViT) backbone from Keras Hub and building my own classification head. My code looks like this:

```python
import keras_hub
from keras import layers
from keras.models import Model

def get_vit_model(model_variant='vit_base', input_shape=(256, 256, 3),
                  num_classes=3, train_base_model=True):
    preset_path = "/home/ahmed/ct_brain_project/models"
    back_bone = keras_hub.models.Backbone.from_preset(preset_path)
    back_bone.trainable = train_base_model

    inputs = layers.Input(shape=input_shape, name='input_layer')
    features = back_bone(inputs, training=train_base_model)

    # Extract CLS token
    cls_token = features[:, 0, :]  # (batch, embed_dim)
    x = layers.Dense(128, use_bias=False)(cls_token)
    # ... rest of the classification head, producing `outputs`

    model = Model(inputs=inputs, outputs=outputs)
    return model
```

From the config I downloaded (vit_base_patch16_224_imagenet from Kaggle), I see:

```json
"class_name": "ViTBackbone",
"config": {
  "use_class_token": true,
  "image_shape": [224, 224, 3],
  "patch_size": [16, 16],
  "num_layers": 12,
  "num_heads": 12,
  "hidden_dim": 768,
  "mlp_dim": 3072
}
```

So my questions are:

1- CLS token extraction: Is `features[:, 0, :]` the correct way to extract the CLS token embedding from the backbone output? I looked at the `ViTPatchingAndEmbedding` class and saw that it does:

```python
patch_embeddings = ops.concatenate(
    [class_token, patch_embeddings], axis=1
)
```

But I am not sure whether this class is actually used in the backbone I downloaded.
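To sanity-check the slicing itself, here is a minimal NumPy sketch of the concatenation order shown above. It assumes the same shapes as the config (196 patches for 224×224 with 16×16 patches, `hidden_dim=768`) and only demonstrates that if the class token is concatenated first along axis 1, then `[:, 0, :]` recovers exactly that token:

```python
import numpy as np

# Stand-in tensors mirroring the ViT config: batch of 2, 196 patch tokens,
# hidden_dim 768. The class token is filled with ones so we can recognize it.
batch, num_patches, hidden_dim = 2, 196, 768
class_token = np.ones((batch, 1, hidden_dim))
patch_embeddings = np.zeros((batch, num_patches, hidden_dim))

# Same concatenation order as in ViTPatchingAndEmbedding:
# class token first, then patch embeddings.
features = np.concatenate([class_token, patch_embeddings], axis=1)
print(features.shape)  # (2, 197, 768)

# Slicing sequence position 0 therefore returns the class-token embedding.
cls = features[:, 0, :]
print(cls.shape)                                 # (2, 768)
print(np.array_equal(cls, class_token[:, 0, :])) # True
```

So as long as `use_class_token` is `true` and the backbone prepends the token as shown, position 0 is the CLS embedding.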

2- Preprocessor: Since I’m using keras_hub.models.Backbone.from_preset(...) directly, am I correct that no ViTImageClassifierPreprocessor is being applied?
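For context on what calling the bare backbone skips, here is a hedged NumPy sketch of the kind of work an image preprocessor typically does (resize to the preset's 224×224 `image_shape`, rescale pixel values). The nearest-neighbour resize and the plain 1/255 scaling are illustrative assumptions, not the preset's actual preprocessing, which should be checked against its preprocessor config:

```python
import numpy as np

def manual_preprocess(images: np.ndarray, target: int = 224) -> np.ndarray:
    """Illustrative stand-in for a ViT image preprocessor:
    nearest-neighbour resize to (target, target) plus 1/255 rescale."""
    b, h, w, c = images.shape
    rows = np.arange(target) * h // target   # source row for each output row
    cols = np.arange(target) * w // target   # source col for each output col
    resized = images[:, rows][:, :, cols]    # (b, target, target, c)
    return resized.astype("float32") / 255.0

batch = np.random.randint(0, 256, size=(2, 256, 256, 3), dtype=np.uint8)
out = manual_preprocess(batch)
print(out.shape, out.dtype)  # (2, 224, 224, 3) float32
print(out.min() >= 0.0 and out.max() <= 1.0)  # True
```

If the backbone really applies no preprocessing, something equivalent to this (with the preset's actual resize mode and scaling) would need to happen before the `Input` layer, or the 256×256 raw pixels go straight into the patch embedding.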

3- Pretraining dataset: The preset I downloaded is named vit_base_patch16_224_imagenet. Is this pretrained on ImageNet‑1k or ImageNet‑21k? (I know Hugging Face has google/vit-base-patch16-224 which is 21k‑pretrained then fine‑tuned on 1k, but I want to confirm for the Keras Hub/Kaggle version.)

Thanks in advance
