telcom committed
Commit 62ff71a · verified · 1 parent: 0a20d74

Update app.py
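
Auto-detect SDXL vs SD 1.x checkpoints: try StableDiffusionXLPipeline first and fall back to StableDiffusionPipeline, then build the matching img2img pipeline from the shared components. Drops the compel prompt-weighting path, the fp16-variant and CPU-fallback handling, and the custom CSS; simplifies the Gradio UI and error reporting.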

Files changed (1)
app.py +104 -180
app.py CHANGED
@@ -8,6 +8,8 @@ from PIL import Image
 
 import torch
 from diffusers import (
+    StableDiffusionPipeline,
+    StableDiffusionImg2ImgPipeline,
     StableDiffusionXLPipeline,
     StableDiffusionXLImg2ImgPipeline,
     EulerAncestralDiscreteScheduler,
@@ -15,7 +17,7 @@ from diffusers import (
 from huggingface_hub import login
 
 # ============================================================
-# GPU decorator (optional)
+# Optional GPU decorator (Spaces)
 # ============================================================
 try:
     import spaces
@@ -24,66 +26,65 @@ except Exception:
     def GPU_DECORATOR(fn):
         return fn
 
-from compel import CompelForSDXL
-
+# ============================================================
+# Config
+# ============================================================
 MODEL_ID = "telcom/dee-unlearning-tiny-sd"
-REVISION="main"
+REVISION = "main"
 
 HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
 if HF_TOKEN:
     login(token=HF_TOKEN)
 
-# ============================================================
-# Detect device
-# ============================================================
 cuda_available = torch.cuda.is_available()
 device = torch.device("cuda" if cuda_available else "cpu")
 dtype = torch.float16 if cuda_available else torch.float32
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1216 if cuda_available else 768  # CPU smaller
+MAX_IMAGE_SIZE = 1216 if cuda_available else 768
 
 pipe_txt2img = None
 pipe_img2img = None
-compel = None
+is_sdxl = False
 model_loaded = False
 load_error = None
-fallback_msg = ""
-
 
 # ============================================================
-# Load model (txt2img + img2img sharing weights)
+# Load model (AUTO detect SDXL vs SD)
 # ============================================================
 try:
     from_pretrained_kwargs = dict(
         torch_dtype=dtype,
-        use_safetensors=False,
+        revision=REVISION,
     )
 
-    if cuda_available:
-        from_pretrained_kwargs["variant"] = "fp16"
-
     if HF_TOKEN:
         from_pretrained_kwargs["token"] = HF_TOKEN
 
-    # Base txt2img pipeline revision=REVISION,
-    pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
-        MODEL_ID, revision=REVISION, **from_pretrained_kwargs
-    )
+    # Try SDXL first
+    try:
+        pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
+            MODEL_ID, **from_pretrained_kwargs
+        )
+        is_sdxl = True
+    except Exception:
+        pipe_txt2img = StableDiffusionPipeline.from_pretrained(
+            MODEL_ID, **from_pretrained_kwargs
+        )
+        is_sdxl = False
+
     pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
         pipe_txt2img.scheduler.config
    )
     pipe_txt2img = pipe_txt2img.to(device)
 
-    # Memory opts
-    try:
-        pipe_txt2img.enable_vae_slicing()
-    except Exception:
-        pass
+    # Memory optimisations
     try:
         pipe_txt2img.enable_attention_slicing()
+        pipe_txt2img.enable_vae_slicing()
     except Exception:
         pass
+
     try:
         pipe_txt2img.enable_xformers_memory_efficient_attention()
     except Exception:
@@ -91,39 +92,31 @@ try:
 
     pipe_txt2img.set_progress_bar_config(disable=True)
 
-    # Create img2img pipeline from txt2img components (no extra weights)
-    pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
+    # Create img2img pipeline
+    if is_sdxl:
+        pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
+    else:
+        pipe_img2img = StableDiffusionImg2ImgPipeline(**pipe_txt2img.components)
+
     pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
         pipe_img2img.scheduler.config
     )
     pipe_img2img = pipe_img2img.to(device)
 
-    try:
-        compel = CompelForSDXL(pipe_txt2img, device=str(device))
-    except TypeError:
-        compel = CompelForSDXL(pipe_txt2img)
-
     model_loaded = True
 
 except Exception as e:
     load_error = repr(e)
     model_loaded = False
 
-
-if not cuda_available:
-    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slower, smaller images)."
-
-
 # ============================================================
-# Error image
+# Helpers
 # ============================================================
 def _make_error_image(w, h, text):
-    img = Image.new("RGB", (w, h), (18, 18, 22))
-    return img
-
+    return Image.new("RGB", (w, h), (30, 30, 40))
 
 # ============================================================
-# Inference (txt2img or img2img depending on init_image)
+# Inference
 # ============================================================
 @GPU_DECORATOR
 def infer(
@@ -135,166 +128,97 @@ def infer(
     height,
     guidance_scale,
     num_inference_steps,
-    init_image,  # new: optional image
-    strength,  # new: img2img strength
+    init_image,
+    strength,
 ):
     width = int(width)
     height = int(height)
-    seed = int(seed)
 
-    if not model_loaded or pipe_txt2img is None or pipe_img2img is None or compel is None:
-        msg = "Model failed to load."
-        if load_error:
-            msg += f" (details: {load_error})"
-        return _make_error_image(width, height, msg), msg
+    if not model_loaded:
+        return _make_error_image(width, height, "Model not loaded"), load_error
 
-    # Randomize seed if requested
     if randomize_seed:
        seed = random.randint(0, MAX_SEED)
 
-    if device.type == "cuda":
-        generator = torch.Generator(device=device).manual_seed(seed)
-    else:
-        generator = torch.Generator().manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
-    status = f"Seed: {seed}"
-    if fallback_msg:
-        status += f" | {fallback_msg}"
+    common_kwargs = dict(
+        guidance_scale=float(guidance_scale),
+        num_inference_steps=int(num_inference_steps),
+        generator=generator,
+    )
 
     try:
         with torch.inference_mode():
-            conditioning = compel(prompt, negative_prompt=negative_prompt)
-
-            common_kwargs = dict(
-                prompt_embeds=conditioning.embeds,
-                pooled_prompt_embeds=conditioning.pooled_embeds,
-                negative_prompt_embeds=conditioning.negative_embeds,
-                negative_pooled_prompt_embeds=conditioning.negative_pooled_embeds,
-                guidance_scale=float(guidance_scale),
-                num_inference_steps=int(num_inference_steps),
-                generator=generator,
-            )
-
-            if device.type == "cuda":
-                with torch.autocast("cuda", dtype=dtype):
-
-                    # If init_image is provided, use img2img
-                    if init_image is not None:
-                        image = pipe_img2img(
-                            image=init_image,
-                            strength=float(strength),
-                            **common_kwargs,
-                        ).images[0]
-                    else:
-                        image = pipe_txt2img(
-                            width=width,
-                            height=height,
-                            **common_kwargs,
-                        ).images[0]
+            if init_image is not None:
+                image = pipe_img2img(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    image=init_image,
+                    strength=float(strength),
+                    **common_kwargs,
+                ).images[0]
             else:
-                if init_image is not None:
-                    image = pipe_img2img(
-                        image=init_image,
-                        strength=float(strength),
-                        **common_kwargs,
-                    ).images[0]
-                else:
-                    image = pipe_txt2img(
-                        width=width,
-                        height=height,
-                        **common_kwargs,
-                    ).images[0]
-
-        return image, status
+                image = pipe_txt2img(
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    width=width,
+                    height=height,
+                    **common_kwargs,
+                ).images[0]
+
+        return image, f"Seed: {seed} | {'SDXL' if is_sdxl else 'SD 1.x'}"
 
     except Exception as e:
-        msg = f"Error during generation: {type(e).__name__}: {e}"
-        return _make_error_image(width, height, msg), msg
+        return _make_error_image(width, height, "Generation error"), str(e)
 
     finally:
         gc.collect()
         if device.type == "cuda":
            torch.cuda.empty_cache()
 
-
 # ============================================================
 # UI
 # ============================================================
+with gr.Blocks(title="Text-to-Image / Image-to-Image") as demo:
+
+    gr.Markdown("## Stable Diffusion Generator")
+
+    if not model_loaded:
+        gr.Markdown(f"⚠️ **Model failed to load**\n\n{load_error}")
+
+    prompt = gr.Textbox(label="Prompt", lines=2)
+    init_image = gr.Image(label="Initial image (optional)", type="pil")
+
+    run_button = gr.Button("Generate")
+    result = gr.Image(label="Result")
+    status = gr.Markdown("")
+
+    with gr.Accordion("Advanced Settings", open=False):
+        negative_prompt = gr.Textbox(label="Negative prompt", value="")
+        seed = gr.Slider(0, MAX_SEED, value=0, step=1, label="Seed")
+        randomize_seed = gr.Checkbox(True, label="Randomize seed")
+        width = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Width")
+        height = gr.Slider(256, MAX_IMAGE_SIZE, step=32, value=512, label="Height")
+        guidance_scale = gr.Slider(0, 20, step=0.1, value=7.5, label="Guidance scale")
+        num_inference_steps = gr.Slider(1, 40, step=1, value=20, label="Steps")
+        strength = gr.Slider(0.0, 1.0, step=0.05, value=0.7, label="Image strength")
+
+    run_button.click(
+        fn=infer,
+        inputs=[
+            prompt,
+            negative_prompt,
+            seed,
+            randomize_seed,
+            width,
+            height,
+            guidance_scale,
+            num_inference_steps,
+            init_image,
+            strength,
+        ],
+        outputs=[result, status],
+    )
 
-CSS = """
-body{
-  background:#000;
-  color:#fff;
-}
-"""
-
-with gr.Blocks(title="Text to Image / Image to Image") as demo:
-
-    gr.HTML(f"<style>{CSS}</style>")
-
-    with gr.Column():
-
-        # banner first
-        if fallback_msg:
-            gr.Markdown(f"**{fallback_msg}**")
-
-        if not model_loaded:
-            gr.Markdown(
-                f"⚠️ **Model failed to load.**\n\nDetails: {load_error}",
-                elem_classes=["small-note"],
-            )
-
-        gr.Markdown("## SDXL Generator (txt2img + img2img)")
-
-        prompt = gr.Textbox(
-            label="Prompt",
-            placeholder="Enter your prompt...",
-            lines=2,
-        )
-
-        # NEW: optional initial image for img2img
-        init_image = gr.Image(
-            label="Initial image (optional)",
-            type="pil",
-        )
-
-        run_button = gr.Button("Generate")
-        result = gr.Image(label="Result")
-        status = gr.Markdown("")
-
-        with gr.Accordion("Advanced Settings", open=False):
-            negative_prompt = gr.Textbox(label="Negative prompt", value="")
-            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
-            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
-            guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=7)
-            num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=40, step=1, value=20)
-
-            # NEW: strength for img2img
-            strength = gr.Slider(
-                label="Image strength (for img2img)",
-                minimum=0.0,
-                maximum=1.0,
-                step=0.05,
-                value=0.7,
-            )
-
-        run_button.click(
-            fn=infer,
-            inputs=[
-                prompt,
-                negative_prompt,
-                seed,
-                randomize_seed,
-                width,
-                height,
-                guidance_scale,
-                num_inference_steps,
-                init_image,
-                strength,
-            ],
-            outputs=[result, status],
-        )
-
-demo.queue().launch(ssr_mode=False)
+demo.queue().launch(ssr_mode=False)
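
Review note: building StableDiffusionImg2ImgPipeline from pipe_txt2img.components means both pipelines share a single copy of the weights; only the scheduler is per-pipeline state, which the commit re-creates from config for each one. Note also that _make_error_image() still ignores its text argument, so error detail only surfaces through the status string. A minimal smoke-test sketch of the sharing pattern follows. It is not part of the commit: it assumes the checkpoint loads as an SD 1.x model (the branch the new fallback takes for a tiny-sd-style checkpoint), and the prompts, sizes, step counts, and output filename are illustrative only.

    # Hypothetical smoke test for the pipeline-sharing pattern in this commit (not in app.py).
    import torch
    from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

    MODEL_ID = "telcom/dee-unlearning-tiny-sd"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Load txt2img once, then reuse its components for img2img, as the diff does.
    txt2img = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
    img2img = StableDiffusionImg2ImgPipeline(**txt2img.components)

    # manual_seed() needs a Python int; the commit dropped `seed = int(seed)`,
    # so a float seed coming from the Gradio slider may need casting first.
    generator = torch.Generator(device=device).manual_seed(42)

    base = txt2img(
        "a lighthouse at dusk", width=512, height=512,
        num_inference_steps=20, guidance_scale=7.5, generator=generator,
    ).images[0]
    redo = img2img(
        "a lighthouse at dawn", image=base, strength=0.7,
        num_inference_steps=20, guidance_scale=7.5, generator=generator,
    ).images[0]
    redo.save("img2img_smoke.png")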