#/*
#* Copyright 2022 Markus Heimerl, University of Tübingen
#* Licensed under CC BY-NC 4.0
#*
#* ANY USE OF THIS SOFTWARE MUST COMPLY WITH THE
#* CREATIVE COMMONS ATTRIBUTION-NONCOMMERCIAL 4.0 INTERNATIONAL LICENSE
#* AGREEMENTS
#*/

from manim import *
import cv2

# what the video should look like
# 3 sections
# 1 define discrete picture
# 2 what is thresholding supposed to do?
# 3 show isodata steps

# script:
# the following shows the isodata automatic thresholding algorithm
# thresholding is used to automatically segment images in fore- and background.
# we will first define a grayscale picture and its histogram mathematically
# and then we will have a look at the algorithm

# rules for video engineering:
# a section is built up like this:
#   I) comment that explains multiple things: 
#       0. the index of the section
#       1. what the section's purpose is
#       2. what objects of previous scenes are being used if any and where they have been declared
#       3. what the script for the section is
#   II) create all objects you are going to use and group them in VGroups if possible (maybe you will use some objects only in unison so they don't even need their own variable names)
#   III) make THE LEAST AMOUNT of self.play() calls POSSIBLE and end the section (e.g. if multiple are in a loop its fine but usually end section after one self.play() call)

# section documentation template:
#################################
## SECTION #X
#################################
## PURPOSE: By showing ... the viewer is ...
## CONTENT: This section defines/shows/...
## EXISTING OBJECTS USED: img_nature[from #5], arrow_std[from #1] ...
## SCRIPT: Let us now look at the definition of ...
#################################


# chapters should be different scenes
# and combine them with 
#class video_title(Scene):
#    def construct(self):
#        scene1.construct()
#        scene2.construct()
# or use ffmpeg if that doesn't work

# Side channel between isodata_one_iteration() and the scene: each successful
# call stores the two class means (rounded to 4 decimals) here so that the
# animation can print them next to the formulas.
global_mu0 = 0
global_mu1 = 0

# Grayscale copy of the example picture, read once at import time from a path
# relative to the working directory.
# NOTE(review): cv2.imread returns None when the file cannot be read - confirm
# that "nature.png" is present wherever the scene is rendered.
img_nature = cv2.imread("nature.png", cv2.IMREAD_GRAYSCALE)

def isodata_one_iteration(img, q):
    """Perform one isodata update step and return the new threshold.

    The gray levels are split at ``q`` into a lower class ``0..q`` and an
    upper class ``q+1..255``.  The new threshold is the truncated midpoint
    of the two class means:

        F(q) = (mu0(q) + mu1(q)) / 2

    Side effect: the two class means, rounded to 4 decimals, are stored in
    the module-level names ``global_mu0`` and ``global_mu1``.

    Parameters:
        img: array of non-negative integer gray values in ``0..255``
             (e.g. the uint8 array returned by ``cv2.imread``).
        q:   current threshold; converted with ``int()``.  If one of the two
             classes would be empty, ``q`` is moved to the nearest value for
             which both classes contain pixels.

    Returns:
        The next threshold as an ``int``.

    Raises:
        ValueError: if ``img`` contains no pixels.
    """
    # Local import keeps this function self-contained; the module header
    # does not import numpy by name.
    import numpy as np

    global global_mu0
    global global_mu1

    # Exact integer histogram H(z) for z = 0..255.  Integer counts avoid the
    # rounding that float32 histogram bins can introduce when summing z*H(z)
    # over a large picture.
    pixels = np.asarray(img)
    hist = np.bincount(pixels.ravel(), minlength=256)[:256].astype(np.int64)

    occupied = np.flatnonzero(hist)
    if occupied.size == 0:
        raise ValueError("img contains no pixels")
    lowest = int(occupied[0])
    highest = int(occupied[-1])

    if lowest == highest:
        # Only one gray level is present: no split can make both classes
        # non-empty, so both means coincide with that level.
        global_mu0 = round(float(lowest), 4)
        global_mu1 = round(float(lowest), 4)
        return lowest

    # Move q to the nearest threshold with two non-empty classes.  This is
    # the closed form of stepping q up while the lower class is empty and
    # down while the upper class is empty.
    q = min(max(int(q), lowest), highest - 1)

    weighted = hist * np.arange(256, dtype=np.int64)

    # Lower class includes level q; upper class runs up to and including 255.
    mu0 = float(weighted[:q + 1].sum() / hist[:q + 1].sum())
    mu1 = float(weighted[q + 1:].sum() / hist[q + 1:].sum())

    global_mu0 = round(mu0, 4)
    global_mu1 = round(mu1, 4)

    return int(0.5 * (mu0 + mu1))


def apply_threshold(img, t):
    """Return a binarised copy of ``img`` using threshold ``t``.

    Every value ``>= t`` becomes 255, every other value becomes 0.  The
    input is left untouched; the result is a new array with the same shape
    and dtype as the input.

    Parameters:
        img: array of gray values (e.g. the uint8 array from ``cv2.imread``).
        t:   threshold compared against each pixel with ``>=``.

    Returns:
        A new numpy array containing only the values 0 and 255.
    """
    # Local import keeps this function self-contained; the module header
    # does not import numpy by name.
    import numpy as np

    pixels = np.asarray(img)
    # One vectorised comparison replaces the per-pixel Python loop.
    return np.where(pixels >= t, 255, 0).astype(pixels.dtype)

class isodata(Scene):
    """Manim scene that walks through the isodata thresholding algorithm.

    The scene plays, in order: a title slide, the formal definition of a
    discrete picture, sample gray values, the thresholding operation, the
    histogram, the isodata formulas, and finally the iteration itself,
    starting from q = 170 and running until the threshold stops changing.

    It reads "nature.png", "nature_isodata.png", "nature_170.png" and
    "nature_15.png" via relative paths and uses the module-level names
    ``img_nature``, ``isodata_one_iteration``, ``apply_threshold``,
    ``global_mu0`` and ``global_mu1``.
    """

    def construct(self):
        """Create all mobjects and play every section of the video."""

        #################################
        ## SECTION #1
        #################################
        ## PURPOSE: This section is an introductory slide to the scene. It is supposed to show in one
        ##          view what the scene is about.
        ## CONTENT: It shows side by side the nature image and the result of the isodata algorithm on it
        ## EXISTING OBJECTS USED: -
        ## SCRIPT: We will have a look at the isodata thresholding algorithm now. This will be done by, first, defining
        ##         a discrete picture formally, second, defining thresholding as a point operation on an image, third,
        ##         defining the histogram as a function of magnitudes of sets and lastly walking through the algorithm
        ##         and showing its iterative and convergent nature.
        #################################
        img = ImageMobject("nature.png").scale(0.75).move_to(LEFT*3)
        title = Tex(r"\underline{Isodata Thresholding Algorithm}").move_to(UP*2)
        img_isodata = ImageMobject("nature_isodata.png").scale(0.75).move_to(RIGHT * 3)
        arrow = Arrow(start=LEFT, end=RIGHT)
        # run_time is the keyword manim uses for the animation duration.
        self.play(Write(title), FadeIn(img), FadeIn(img_isodata), FadeIn(arrow), run_time=3)
        self.wait(3)
        # img stays on screen; it is reused by the following sections.
        self.play(FadeOut(title), FadeOut(img_isodata), FadeOut(arrow))

        self.next_section()

        # image definition
        line1 = MathTex(r"D \subseteq \mathbb{N}^2", r"\quad and \quad G", r"\subseteq \mathbb{N}")
        line2 = MathTex(r"B: D \longrightarrow G")
        line3 = MathTex(r"B: \{1,...,M\} \times \{1,...,N\} \longrightarrow \{0,...,g_{max}-1\}")  # , so (m,n) \mapsto B(m,n)
        VGroup(line1, line2, line3).arrange(DOWN).move_to(DOWN*1.5)
        self.play(img.animate.move_to(UP*1.5))
        self.play(Write(line1))
        self.play(Write(line2))
        self.play(Write(line3))
        self.wait()
        self.play(FadeOut(line1), FadeOut(line2), FadeOut(line3))

        self.next_section()

        # showing example picture and sample grayscale values
        self.play(img.animate.move_to(UP))
        line3.move_to(DOWN)
        self.play(FadeIn(line3))
        replacement1 = MathTex(r"B: \{1,...,512\} \times \{1,...,681\} \longrightarrow \{0,...,256-1\}").move_to(DOWN)
        self.play(ReplacementTransform(line3, replacement1))
        self.play(img.animate.move_to(LEFT*2 + UP))
        # tiny red squares mark the sampled pixels on the picture
        square1 = Square(side_length=0.01, color=PURE_RED).move_to(LEFT + UP + LEFT*1.78 + UP*0.24)
        sample1 = MathTex(r"B(221,200) = 45").move_to(RIGHT*2 + UP)
        self.play(FadeIn(square1), FadeIn(sample1))
        self.wait()
        square2 = Square(side_length=0.01, color=PURE_RED).move_to(LEFT + UP + LEFT*0.9 - UP*0.3)
        sample2 = MathTex(r"B(311,361) = 179").move_to(RIGHT*2 + UP)
        self.play(ReplacementTransform(square1, square2), ReplacementTransform(sample1, sample2))
        self.wait()
        self.play(FadeOut(img), FadeOut(square2), FadeOut(sample2), FadeOut(replacement1))

        self.next_section()

        # showing the thresholding transformation: grayscale to a binary image
        title_thresh = Tex(r"\underline{Thresholding}").move_to(UP*2)
        thresh = MathTex(r"B^{'}(m,n) =\left\{ \begin{array}{ c l }g_{max}-1 & \quad \textrm{if } B(m,n) \geq q \\ 0 & \quad \textrm{otherwise} \end{array} \right.")
        thresh170 = MathTex(r"B^{'}(m,n) =\left\{ \begin{array}{ c l }g_{max}-1 & \quad \textrm{if } B(m,n) \geq 170 \\ 0 & \quad \textrm{otherwise} \end{array} \right.").move_to(UP*2)
        thresh15 = MathTex(r"B^{'}(m,n) =\left\{ \begin{array}{ c l }g_{max}-1 & \quad \textrm{if } B(m,n) \geq 15 \\ 0 & \quad \textrm{otherwise} \end{array} \right.").move_to(UP*2)
        img_170 = ImageMobject("nature_170.png").scale(0.75).move_to(RIGHT * 3+DOWN)
        img_15 = ImageMobject("nature_15.png").scale(0.75).move_to(RIGHT * 3+DOWN)
        arrow = Arrow(start=LEFT+DOWN, end=RIGHT+DOWN)

        self.play(Write(thresh), Write(title_thresh))
        self.wait(3)
        img.move_to(LEFT*3+DOWN)
        self.play(FadeOut(title_thresh), ReplacementTransform(thresh, thresh170), FadeIn(img), FadeIn(img_170), FadeIn(arrow))
        self.wait(3)
        self.play(ReplacementTransform(img_170, img_15), ReplacementTransform(thresh170, thresh15))
        self.wait(3)
        self.play(FadeOut(img_15), FadeOut(arrow), FadeOut(thresh15))

        self.next_section()

        # showing histogram and histogram function
        title_hist = Tex(r"\underline{Histogram}").move_to(UP*2.5)
        img_cv2 = cv2.imread("nature.png")
        # histogram of channel 0 of the picture scaled to 75 % of its size
        scaled_size = (int(img_cv2.shape[1] * 75 / 100), int(img_cv2.shape[0] * 75 / 100))
        img_scaled = cv2.resize(img_cv2, scaled_size, interpolation=cv2.INTER_AREA)
        histr = [j for sub in cv2.calcHist([img_scaled], [0], None, [256], [0, 256]) for j in sub]
        # the last bin is overwritten with a fixed value before plotting
        histr[-1] = 5
        chart = BarChart(
            values=histr,
            x_length=8,
            y_range=[0, max(histr)+30, 500],
            y_axis_config={"include_tip": True},
            x_axis_config={"include_ticks": False, "unit_size": 0.5, "include_tip": True},
            bar_stroke_width=1,
        ).move_to(RIGHT * 3).scale(0.5)

        arrow = Arrow(start=LEFT+LEFT*0.1, end=RIGHT-RIGHT*0.1)
        x_axes_label = MathTex(r"g \in G").scale(0.5).next_to(chart, DOWN*0.25+RIGHT*0.01)
        y_axes_label = MathTex(r"H(g)").scale(0.5).next_to(chart, LEFT*0.01+UP*0.2)
        self.play(img.animate.move_to(LEFT * 3), FadeIn(arrow))
        self.play(Write(title_hist), Create(chart), Write(x_axes_label), Write(y_axes_label), run_time=2)
        self.wait()
        self.play(Group(img, arrow, chart, x_axes_label, y_axes_label).animate.move_to(UP), title_hist.animate.move_to(UP*3))
        hg = MathTex(r"H(g)=|\{(m,n): B(m,n)=g\}|").scale(0.75).move_to(DOWN*1.5)
        self.play(Write(hg))
        self.wait(2)
        self.play(FadeOut(Group(img, arrow, chart, x_axes_label, y_axes_label, hg, title_hist)))

        self.next_section()

        # write down the algorithm
        # these should either be a group or just one aligned latex string
        qn = MathTex(r"q_{n+1} = F(q)")
        f = MathTex(r"F(q) = \frac{\mu_{0}(q) + \mu_{1}(q)}{2}")
        mu = MathTex(r"\mu_{0}(q) = \frac{\sum^{q}_{z = 0}z \cdot H(z)}{\sum^{q}_{z = 0}H(z)} \quad \mu_{1}(q) = \frac{\sum^{g_{max}-1}_{z = q+1}z \cdot H(z)}{\sum^{g_{max}-1}_{z = q+1}H(z)}")
        VGroup(qn, f, mu).arrange(DOWN)
        self.play(Write(qn))
        self.play(Write(f))
        self.play(Write(mu))
        self.wait()

        self.next_section()

        # show the iterative and convergent nature of the algorithm
        # these should either be a group or just one aligned latex string
        q0 = MathTex(r"q_0 = 170")
        f0 = MathTex(r"F(170) = \frac{\mu_{0}(170) + \mu_{1}(170)}{2}")
        mu0 = MathTex(r"\mu_{0}(170) = \frac{\sum^{170}_{z = 0}z \cdot H(z)}{\sum^{170}_{z = 0}H(z)} \quad \mu_{1}(170) = \frac{\sum^{255}_{z = 171}z \cdot H(z)}{\sum^{255}_{z = 171}H(z)}")
        VGroup(q0, f0, mu0).arrange(DOWN).scale(0.5).move_to(UP*2)

        self.play(VGroup(qn, f, mu).animate.scale(0.5).move_to(UP*2))
        arrow = Arrow(start=LEFT+DOWN, end=RIGHT+DOWN)
        img.move_to(LEFT*3+DOWN)
        self.play(FadeIn(img), FadeIn(arrow))
        self.play(ReplacementTransform(qn, q0), ReplacementTransform(f, f0), ReplacementTransform(mu, mu0))
        img_170 = ImageMobject(apply_threshold(img_nature, 170)).scale(0.75).move_to(RIGHT * 3+DOWN)
        self.play(FadeIn(img_170))
        self.wait()

        # from here on out the isodata should automatically run and the animation should automatically update with it
        # first of all we need the isodata to run on a picture for just one iteration
        q = 170
        iteration = 1
        q_old = q0
        f_old = f0
        mu_old = mu0
        img_thresh_old = img_170
        while True:
            q_new = isodata_one_iteration(img_nature, q)

            # Converged: stop before building mobjects that would never be shown.
            if q_new == q:
                break

            # these should either be a group or just one aligned latex string
            # The formulas are labelled with the threshold the step started
            # from (q); global_mu0 / global_mu1 hold the class means that
            # isodata_one_iteration just stored.
            q_loop = MathTex(
                r"q_{" + str(iteration) + r"} = " + str(q_new)
            ).scale(0.5).move_to(q_old.get_center())
            f_loop = MathTex(
                r"F(" + str(q) + r") = \frac{\mu_{0}(" + str(q)
                + r") + \mu_{1}(" + str(q) + r")}{2} = " + str(q_new)
            ).scale(0.5).move_to(f_old.get_center())
            mu_loop = MathTex(
                r"\mu_{0}(" + str(q) + r") = \frac{\sum^{" + str(q)
                + r"}_{z = 0}z \cdot H(z)}{\sum^{" + str(q)
                + r"}_{z = 0}H(z)} = " + str(global_mu0)
                + r" \quad \mu_{1}(" + str(q)
                + r") = \frac{\sum^{255}_{z = " + str(q+1)
                + r"}z \cdot H(z)}{\sum^{255}_{z = " + str(q+1)
                + r"}H(z)} = " + str(global_mu1)
            ).scale(0.5).move_to(mu_old.get_center())
            img_thresh = ImageMobject(apply_threshold(img_nature, q_new)).scale(0.75).move_to(RIGHT * 3+DOWN)

            q = q_new

            self.play(
                ReplacementTransform(q_old, q_loop),
                ReplacementTransform(img_thresh_old, img_thresh),
                ReplacementTransform(f_old, f_loop),
                ReplacementTransform(mu_old, mu_loop),
            )

            q_old = q_loop
            f_old = f_loop
            mu_old = mu_loop
            img_thresh_old = img_thresh

            iteration += 1
            self.wait()

        # Clear everything except the final thresholded picture and centre it.
        self.play(FadeOut(q_old, mu_old, img, arrow, f_old), img_thresh_old.animate.scale(1.0).move_to(ORIGIN))
        self.wait(5)
        self.play(FadeOut(img_thresh_old))