Unsloth - Model Visualization
Model Visualization
from transformers import BatchEncoding, TextStreamer
from unsloth import FastLanguageModel

# Loading configuration for the 20B gpt-oss checkpoint via Unsloth's fast loader.
_LOAD_KWARGS = dict(
    model_name="unsloth/gpt-oss-20b",
    dtype=None,  # None -> auto detection (e.g. torch.bfloat16)
    max_seq_length=1000,
    load_in_4bit=False,
    full_finetuning=False,
    low_cpu_mem_usage=True,
    device_map="cuda",  # Explicitly load to CUDA
)
model, tokenizer = FastLanguageModel.from_pretrained(**_LOAD_KWARGS)
Visualization Code
import torch
from rich.tree import Tree
from rich import print as rprint
def visualize_model_structure(model):
    """Render a PyTorch module hierarchy as a ``rich`` tree and print it.

    Walks ``model.named_modules()`` and mirrors each dotted module path
    (e.g. ``model.layers.0.self_attn``) as a tree node.  Leaf layers that
    expose a ``weight`` tensor (Linear, Embedding, LayerNorm, ...) are
    annotated with the weight shape; other modules that directly own
    parameters are annotated with their parameter count.

    Args:
        model: Any ``torch.nn.Module``.  If the model carries a
            ``config._name_or_path`` attribute it is used as the root label;
            otherwise the label falls back to ``'Unknown'``.

    Returns:
        The constructed ``rich.tree.Tree`` (also printed to stdout), so
        callers may embed or re-render it.  Previously the function
        returned ``None``; returning the tree is backward compatible.
    """
    # Harden against models without a .config attribute (plain nn.Module).
    config = getattr(model, 'config', None)
    root_label = getattr(config, '_name_or_path', 'Unknown')

    # 1. Create root node
    tree = Tree(f"ποΈ [bold blue]Model: {root_label}[/bold blue]")

    # Map dotted module path -> rich tree node; "" is the root module itself.
    node_lookup = {"": tree}

    for name, module in model.named_modules():
        if name == "":
            continue  # root module is already represented by `tree`

        # Split path: 'model.layers.0.self_attn' -> ['model', 'layers', '0', 'self_attn']
        parts = name.split('.')
        parent_path = ".".join(parts[:-1])
        current_part = parts[-1]

        # Parameters owned directly by this module only
        # (recurse=False excludes children, which get their own nodes).
        params_count = sum(p.numel() for p in module.parameters(recurse=False))

        # Prefer the weight shape for leaf layers; otherwise show the raw count.
        shape_info = ""
        if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
            shape_info = f" [yellow]({list(module.weight.shape)})[/yellow]"
        elif params_count > 0:
            shape_info = f" [green]({params_count:,} params)[/green]"

        # 2. Find or create node.  named_modules() yields parents before
        # children, so the parent normally exists; if it was never created,
        # the subtree is silently skipped (same behavior as before).
        parent_node = node_lookup.get(parent_path)
        if parent_node is not None:
            new_node = parent_node.add(
                f"[bold magenta]{current_part}[/bold magenta]{shape_info}"
            )
            node_lookup[name] = new_node

    rprint(tree)
    return tree
# Execution
# Build and print the module-structure tree for the loaded model.
visualize_model_structure(model)
Here's the result:
ποΈ Model: unsloth/gpt-oss-20b
βββ model
β βββ embed_tokens ([201088, 2880])
β βββ layers
β β βββ 0
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 1
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 2
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 3
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 4
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 5
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 6
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 7
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 8
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 9
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 10
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 11
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 12
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 13
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 14
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 15
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 16
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 17
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 18
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 19
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 20
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 21
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 22
β β β βββ self_attn (64 params)
β β β β βββ q_proj ([4096, 2880])
β β β β βββ k_proj ([512, 2880])
β β β β βββ v_proj ([512, 2880])
β β β β βββ o_proj ([2880, 4096])
β β β βββ mlp
β β β β βββ router ([32, 2880])
β β β β βββ experts (796,538,880 params)
β β β βββ input_layernorm ([2880])
β β β βββ post_attention_layernorm ([2880])
β β βββ 23
β β βββ self_attn (64 params)
β β β βββ q_proj ([4096, 2880])
β β β βββ k_proj ([512, 2880])
β β β βββ v_proj ([512, 2880])
β β β βββ o_proj ([2880, 4096])
β β βββ mlp
β β β βββ router ([32, 2880])
β β β βββ experts (796,538,880 params)
β β βββ input_layernorm ([2880])
β β βββ post_attention_layernorm ([2880])
β βββ norm ([2880])
β βββ rotary_emb
βββ lm_head ([201088, 2880])