Tags: Safetensors · GGUF · abliterated · uncensored · conversational
paperscarecrow committed · Commit 8ddc49c · verified · 1 parent: f93f318

Upload gemma4_31b_abliterator.py

Files changed (1): gemma4_31b_abliterator.py (+149, -0)
gemma4_31b_abliterator.py ADDED
@@ -0,0 +1,149 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import json
import os
from tqdm import tqdm
from datasets import load_dataset
import random

# --- CONFIGURATION ---
MODEL_ID = "google/gemma-4-31B-it"  # Adjust if your local path differs
SAVE_PATH = "./gemma-4-31b-abliterated"
BATCH_SIZE = 4  # Keep this low to survive the 31B hidden state extraction
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[*] Initializing Gemma 4 31B Abliteration Protocol on {DEVICE}...")

# --- 1. LOAD MODEL & TOKENIZER ---
print("[*] Loading Model and Tokenizer (bfloat16)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Left-pad so that index -1 always points at the final real token during batched extraction
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto"  # Let accelerate distribute the 62GB across your GPUs
)
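
# Optional: confirm how accelerate actually sharded the model across devices.
# `hf_device_map` is only set when `device_map` was used, hence the guard.
if hasattr(model, "hf_device_map"):
    print(f"[*] Device map: {model.hf_device_map}")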

# --- 2. DATA PREPARATION ---
print("[*] Downloading HuggingFace datasets...")

# Load the datasets
harmful_dataset = load_dataset('mlabonne/harmful_behaviors')
harmless_dataset = load_dataset('mlabonne/harmless_alpaca')

# Extract the raw text prompts.
# We shuffle and slice 256 samples to keep VRAM extraction manageable but statistically significant.
raw_harmful = random.sample(harmful_dataset['train']['text'], 256)
raw_harmless = random.sample(harmless_dataset['train']['text'], 256)
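
# Optional: log how much data each pool actually contains versus the 256 sampled
print(f"[*] Pool sizes: {len(harmful_dataset['train'])} harmful / "
      f"{len(harmless_dataset['train'])} harmless (256 sampled from each)")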

def format_gemma4_prompts(instructions):
    """Uses the native Gemma 4 chat template with system roles."""
    formatted = []
    for inst in instructions:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": inst}
        ]
        # The tokenizer handles all the <start_of_turn> control tokens
        formatted.append(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
    return formatted

print("[*] Formatting prompts with native Gemma 4 Chat Templates...")
harmful_prompts = format_gemma4_prompts(raw_harmful)
harmless_prompts = format_gemma4_prompts(raw_harmless)
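
# Optional: eyeball one formatted prompt to confirm the chat template rendered
# the control tokens as expected before committing to the long extraction pass
print(f"[*] Sample formatted prompt:\n{harmless_prompts[0]}")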

# --- 3. HIDDEN STATE EXTRACTION (VRAM SAFE) ---
def get_hidden_states(prompts, batch_size=BATCH_SIZE):
    print(f"[*] Extracting hidden states (Batches of {batch_size})...")
    all_hidden_states = []

    for i in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[i:i+batch_size]
        inputs = tokenizer(batch, padding=True, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            # outputs.hidden_states is a tuple of (num_layers + 1) tensors.
            # Shape of each tensor: [batch_size, sequence_length, hidden_dim].
            # We want the last token's state across ALL layers.

            # Stack to: [num_layers+1, batch, seq, dim]
            stacked_states = torch.stack(outputs.hidden_states)
            # Extract the last token: [num_layers+1, batch, dim].
            # Index -1 is safe because padding_side="left" keeps the real final token last.
            last_token_states = stacked_states[:, torch.arange(len(batch)), -1, :]

            # IMMEDIATELY move to CPU float32 to save VRAM
            all_hidden_states.append(last_token_states.cpu().float())

        del inputs, outputs, stacked_states, last_token_states
        torch.cuda.empty_cache()
        gc.collect()

    # Concatenate along the batch dimension: [num_layers+1, total_prompts, hidden_dim]
    return torch.cat(all_hidden_states, dim=1)

print("\n[*] Processing Harmful Vector Space...")
harmful_states = get_hidden_states(harmful_prompts)
print("[*] Processing Harmless Vector Space...")
harmless_states = get_hidden_states(harmless_prompts)
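
# Optional shape sanity check: both tensors should now be
# [num_layers + 1, 256, hidden_dim], living on the CPU
assert harmful_states.shape == harmless_states.shape
print(f"[*] Hidden state tensor shape: {tuple(harmful_states.shape)}")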

# --- 4. DYNAMIC LAYER HUNTING ---
print("\n[*] Hunting for the Refusal Vector...")
mean_harmful = harmful_states.mean(dim=1)
mean_harmless = harmless_states.mean(dim=1)

refusal_directions = mean_harmful - mean_harmless

# Find the state index with the highest magnitude
# (index 0 is the embedding output, so we skip it and re-offset the argmax)
magnitudes = torch.norm(refusal_directions[1:], dim=1)
peak_state_idx = torch.argmax(magnitudes).item() + 1

print(f"[+] Peak Refusal Mass detected at state index: {peak_state_idx}")

# Normalize the refusal vector
refusal_vector = refusal_directions[peak_state_idx]
refusal_vector = (refusal_vector / torch.norm(refusal_vector)).to(DEVICE).to(torch.bfloat16)
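
# Optional: list the strongest candidate states, handy when the peak is ambiguous
# (k=5 is an arbitrary choice for inspection)
top_mags, top_idxs = torch.topk(magnitudes, k=5)
for mag, idx in zip(top_mags.tolist(), top_idxs.tolist()):
    print(f"    state {idx + 1}: refusal magnitude {mag:.4f}")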

# --- 5. ORTHOGONAL PROJECTION (THE ABLITERATION) ---
# FIX 1: Safely navigate the Gemma 4 Multimodal Config
num_layers = model.config.text_config.num_hidden_layers if hasattr(model.config, 'text_config') else model.config.num_hidden_layers

# FIX 2: Correct the off-by-one mapping (state index 60 comes from layer 59)
target_layer_idx = peak_state_idx - 1

print(f"\n[*] Applying Orthogonal Projection starting at Layer {target_layer_idx}...")

# FIX 3: Bulletproof dynamic layer discovery for multimodal models
def get_transformer_layers(model_obj, target_len):
    for name, module in model_obj.named_modules():
        if name.endswith('layers') and isinstance(module, torch.nn.ModuleList) and len(module) == target_len:
            return module
    return model_obj.model.layers  # Fallback

transformer_layers = get_transformer_layers(model, num_layers)

# Pre-calculate column and row vectors for the linear algebra
v_col = refusal_vector.unsqueeze(1)  # Shape: (5376, 1)
v_row = refusal_vector.unsqueeze(0)  # Shape: (1, 5376)

# Abliterate the target layer and up to 4 subsequent layers (capped safely by num_layers)
for layer_idx in range(target_layer_idx, min(target_layer_idx + 5, num_layers)):
    print(f" -> Abliterating Layer {layer_idx}...")

    o_proj = transformer_layers[layer_idx].self_attn.o_proj.weight.data
    down_proj = transformer_layers[layer_idx].mlp.down_proj.weight.data

    # With device_map="auto" the layers may sit on different GPUs, so move the
    # refusal vectors onto each layer's device before multiplying
    vc, vr = v_col.to(o_proj.device), v_row.to(o_proj.device)

    # CORRECTED MATH: W -= v @ (v^T @ W), i.e. remove the refusal component from the output
    projection_o = torch.matmul(vc, torch.matmul(vr, o_proj))
    transformer_layers[layer_idx].self_attn.o_proj.weight.data -= projection_o

    projection_down = torch.matmul(vc, torch.matmul(vr, down_proj))
    transformer_layers[layer_idx].mlp.down_proj.weight.data -= projection_down
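
# Optional verification: after the subtraction, the refusal direction should be
# numerically orthogonal to the ablated weights, i.e. ||v^T W|| ~ 0 up to bf16 noise
with torch.no_grad():
    w = transformer_layers[target_layer_idx].self_attn.o_proj.weight.data
    residual = torch.norm(torch.matmul(v_row.to(w.device), w)).item()
    print(f"[*] Residual refusal mass in layer {target_layer_idx} o_proj: {residual:.2e}")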

# --- 6. CRYSTALLIZATION ---
print(f"\n[*] Abliteration Complete. Saving uncensored weights to {SAVE_PATH}...")
model.save_pretrained(SAVE_PATH)
tokenizer.save_pretrained(SAVE_PATH)
print("[+] SUCCESS: The 31B Teacher is ready to wake up.")
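
# --- OPTIONAL SMOKE TEST (sketch) ---
# Reload the saved weights, ideally in a fresh process since the 31B model is
# still resident here, and confirm the model still generates coherently.
# The prompt below is an arbitrary placeholder.
#
# tok = AutoTokenizer.from_pretrained(SAVE_PATH)
# m = AutoModelForCausalLM.from_pretrained(SAVE_PATH, torch_dtype=torch.bfloat16, device_map="auto")
# msgs = [{"role": "user", "content": "Summarize orthogonal projection in one sentence."}]
# ids = tok.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True).to(m.device)
# out = m.generate(ids, max_new_tokens=64)
# print(tok.decode(out[0][ids.shape[-1]:], skip_special_tokens=True))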