# Math 537, Introduction to Numerical Analysis 1, Fall 2013
# Final Project
# File created November 1st, 2013
# Modified for distribution May 2018 (Python 2 -> Python 3)
# Modified for Pygame 2 June 2022
#--------------------------------------------------------------------------

import sys

# Stop straight away on Python 2, reporting the required version.
if sys.version_info < (3,):
    print("This script requires Python version 3.0 or higher")
    sys.exit(1)

import pygame
import numpy as np
import matplotlib.pyplot as plt
from pygame.locals import *

# The window is a square, HEIGHT pixels on each side
HEIGHT = 640
SCREEN_SIZE = (HEIGHT, HEIGHT)
# Interaction modes; the values index into MODES below
ORBIT_MODE = 0
MAGNIFY_MODE = 1
CYCLE_MODE = 2
FEIGENBAUM_MODE = 3
DISTANCE_MODE = 4
# Image maps; the values index into MAPS below
DIVERGENCE_MAP = 0
HANKEL_MAP = 1
CONVERGENCE_MAP = 2
PERIOD_MAP = 3
ANGLE_MAP = 4
CYCLE_MAP = 5
# Grid overlays drawn on top of the image
NO_GRID = 0
CARTESIAN_GRID = 1
POLAR_GRID = 2
# Tolerance used by the hankel and cycle maps when comparing orbit points
EPSILON = 10**(-6)
# Number keys in order, so a key's position is the index of the palette it selects
NUMBER_KEYS = (K_0, K_1, K_2, K_3, K_4, K_5, K_6, K_7, K_8, K_9)

# Display names; MODES and MAPS are used to build the window title
MODES = ("Orbit", "Magnify", "Cycle", "Feigenbaum", "Distance")
MAPS = ("Divergence", "Hankel", "Convergence", "Period", "Angle", "Cycle")
GRIDS = ("No", "Cartesian", "Polar")

#--------------------------------------------------------------------------
#
# Color maps
#
#--------------------------------------------------------------------------

def make_palette(colors, nodes):
    '''Build a color palette by blending linearly between color stops.

    colors -- sequence of (r, g, b) triples, one more than there are nodes
    nodes  -- palette positions at which each blend ends; the blend from
              colors[k] to colors[k + 1] runs from the previous node (0 for
              the first blend) up to nodes[k]

    Every blend contributes both of its end points, so the result holds
    nodes[-1] + len(nodes) entries.  Returns a numpy array of colors packed
    as integers: red*2**16 + green*2**8 + blue.
    '''
    entries = []
    start = 0
    for k, stop in enumerate(nodes):
        r1, g1, b1 = colors[k]
        r2, g2, b2 = colors[k + 1]
        span = stop - start
        dr, dg, db = r2 - r1, g2 - g1, b2 - b1
        for step in range(span + 1):
            entries.append((r1 + step*dr//span,
                            g1 + step*dg//span,
                            b1 + step*db//span))
        start = stop
    # Pygame 2 no longer takes (r, g, b) tuples here, so pack each color
    # into a single rgb integer, one channel at a time
    channels = [np.array([entry[ch] for entry in entries]) for ch in range(3)]
    return (2**16)*channels[0] + (2**8)*channels[1] + channels[2]

def make_palette_cycle(initial, colors):
    '''Build a 256 entry palette that repeats the given colors in turn.

    initial -- (r, g, b) color stored at index 0
    colors  -- sequence of (r, g, b) colors; index i (1..255) receives
               colors[i % len(colors)]

    Returns a numpy array of colors packed as integers:
    red*2**16 + green*2**8 + blue.
    '''
    count = len(colors)
    entries = [initial]
    entries.extend(colors[slot % count] for slot in range(1, 256))
    # Pygame 2 no longer takes (r, g, b) tuples here, so pack each color
    # into a single rgb integer, one channel at a time
    channels = [np.array([entry[ch] for entry in entries]) for ch in range(3)]
    return (2**16)*channels[0] + (2**8)*channels[1] + channels[2]

# Predefined palettes. For make_palette, the first list holds the color
# stops and the second list holds the palette index at which each blend
# between two neighbouring stops ends.
blackwhite = make_palette([(0,0,0),
                           (255, 255, 255),
                           (255, 255, 255)],
                          [1, 255])
heatmap = make_palette([(0, 0, 0),
                         (0,0,160), # Dark blue
                         (0, 255, 255), # Light blue
                         (128, 255, 128), # Light green
                         (255, 255, 0), # Yellow
                         (255, 128, 64), # Light orange
                         (255, 128, 0)], # Orange
                        [2, 16, 32, 64, 128, 255])
jewel = make_palette([(0, 0, 0),
                      (0, 0, 128),
                      (128, 0, 128),
                      (221, 155, 34),
                      (0, 64, 0),
                      (0, 0, 255),
                      (0, 0, 128),
                      (0, 64, 0),
                      (221, 155, 34),
                      (128, 0, 128),
                      (0, 0, 255)],
                     [2, 4, 6, 8, 12, 16, 32, 64, 128, 255])
blackyellow = make_palette([(0, 0, 0),
                         (0, 0, 0),
                         (255,255,0),
                         (0, 0, 0)],
                        [10, 64, 255])
redpurple = make_palette([(0, 0, 0),
                          (240, 160, 20),
                          (231, 55, 24),
                          (78,47,170),
                          (22, 9, 151),
                          (240, 160, 20),
                          (231, 55, 24),
                          (78,47,170),
                          (22, 9, 151),
                          (255, 0, 0)],
                         [1, 2, 4, 8, 16, 32, 64, 128, 255])
indi = make_palette([(0, 0, 255),
                     (27, 198, 254),
                     (145, 102, 255),
                     (195, 23, 255),
                     (71, 20, 205)],
                    [64, 128, 192, 255])
indi2 = make_palette([(0, 0, 0),
                      (55, 16, 243),
                      (255, 0, 77),
                      (196, 0, 166),
                      (0, 255, 147),
                      (0, 128, 255)],
                     [2, 32, 64, 128, 255])
rainbow = make_palette([(0, 0, 0),
                        (0, 0, 255), # Blue
                        (0, 255, 0), # Green
                        (255, 255, 0), # Yellow
                        (255, 128, 0), # Orange
                        (255, 0, 0), # Red
                        (128, 0, 255), # Purple
                        (0, 0, 255)], # Blue
                        [1, 43, 85, 127, 170, 213, 255])
# ocean is not interpolated: index 0 is black and the listed colors repeat
# in order over the remaining indexes
ocean = make_palette_cycle((0, 0, 0),
                             [(30, 144, 255),
                              (0, 128, 128),
                              (0, 100, 0),
                              (50, 205, 50),
                              (0, 255, 127),
                              (64, 100, 205),
                              (0, 191, 255),
                              (150, 80, 224),
                              (148, 0, 211),
                              (199, 21, 133),
                              (138, 43, 226),
                              (0, 255, 127),
                              (0, 206, 209)])

#--------------------------------------------------------------------------
#
# Class to hold all the data
#
#--------------------------------------------------------------------------

class Mandelbrot():

    def __init__(self):
        '''Create the explorer in magnify mode showing the divergence map.

        Fonts are created here, so pygame has to be initialised before an
        instance is constructed (the script calls pygame.init() first).
        '''
        # Active mode and map, palette cycling direction, pause flag
        self.mode, self.map = MAGNIFY_MODE, DIVERGENCE_MAP
        self.cycle, self.paused = True, False

        # The image is square
        self.width = self.height = HEIGHT

        # Where the iteration number and the value under the cursor are drawn
        self.text_x, self.text_y = 30, 20
        self.cursor_x, self.cursor_y = HEIGHT - 70, 20

        # Magnification rectangle; magnify is True while one is being dragged
        self.startx = self.starty = 0
        self.endx = self.endy = 0
        self.magnify = False

        # Iteration counter and the limit applied to orbits in orbit mode
        self.itermax = 10000
        self.iter = 0

        # Orbit as screen points (orbit) and as complex numbers (zorbit),
        # together with the parameter c the orbit belongs to
        self.orbit, self.zorbit = [], []
        self.zc = 0

        # Palettes selectable with the number keys, and the active palette
        self.palettes = [blackwhite, heatmap, jewel, rainbow,
                         blackyellow, redpurple, indi, indi2, ocean]
        self.palette = heatmap
        # Drawing colors
        self.color = [0, 0, 128]
        self.word_color = [255, 255, 0]
        self.rect_color = [255, 255, 0]
        self.line_color = [255, 255, 0]
        self.grid_color = [127, 127, 127]

        # Display options
        self.levels = 256
        self.normalize = False
        self.grid = NO_GRID
        self.origin = ((HEIGHT * 2)//3, HEIGHT//2)
        self.current_pos = (0, 0)

        # Fonts for on-screen text
        self.word_font = pygame.font.SysFont("courant", 16, True)
        self.button_font = pygame.font.SysFont("courant", 28, True)

        # The viewed region has to be set before the image is built from it
        self.initialize_area()
        self.initialize_image()

        pygame.surfarray.use_arraytype('numpy')

    def add_palette(self, palette):
        '''Put palette at the front of the palette list and make it active.'''
        extended = [palette] + self.palettes
        self.palettes, self.palette = extended, palette

    def initialize_area(self):
        '''Reset the viewed region to -2 <= x <= 1, -1.5 <= y <= 1.5.'''
        self.xmin, self.xmax = -2., 1.
        self.ymin, self.ymax = -1.5, 1.5

    def update_title(self):
        '''Show the names of the active map and mode in the window caption.'''
        caption = "Mandelbrot Explorer by Adam Cunningham: {} map, {} mode".format(
            MAPS[self.map], MODES[self.mode])
        pygame.display.set_caption(caption)

    #------------------------------------------------------------------------
    # Initialize the image data. This is done each time the map type is
    # changed or the image window is changed through reset or magnification.
    # One dimensional arrays are used for integer indexing of just the points
    # which have not diverged.
    # self.ix: 1D int array holding x indexes of "currently active" points.
    # self.iy: 1D int array holding y indexes of "currently active" points.
    # self.z: 1D complex array holding the current z values after iterating
    # self.c: 1D complex array holding the initial c value of each active point
    # self.img is a 2D integer array holding the image for the current map.
    #------------------------------------------------------------------------

    def initialize_image(self):
        '''Restart the computation for the current map and viewed region.

        Updates the window title, resets the iteration counter and rebuilds
        every working array: the pixel indexes (ix, iy), the parameter values
        c, the orbit values z, the per-map bookkeeping arrays and the image.
        '''
        self.update_title()
        self.iter = 0
        n = self.width
        # Rows come from the image height (this used to read self.width)
        m = self.height
        ix, iy = np.mgrid[0:n, 0:m]
        x = np.linspace(self.xmin, self.xmax, n)[ix]
        # Array image is stored upside down for correct display
        y = np.linspace(self.ymax, self.ymin, m)[iy]
        # c holds the complex plane - the parameter space
        c = x+complex(0,1)*y
        del x, y
        # img holds the value shown for each pixel, e.g. the number of
        # iterations needed for divergence
        img = np.zeros(c.shape, dtype=int)
        # Flatten to one dimension so points can be dropped as they finish.
        # reshape is used instead of assigning to .shape, which the NumPy
        # documentation discourages.
        self.ix = ix.reshape(n*m)
        self.iy = iy.reshape(n*m)
        c = c.reshape(n*m)
        self.c = c
        # z holds the results of the latest iteration. Start at c.
        self.z = np.copy(c)
        # Initialize the internal map - the first closest point is c itself
        self.closest = np.abs(c)
        self.previous = np.zeros_like(c)
        # saved starts with the same array object as previous
        self.saved = [self.previous]
        self.angles = np.zeros(c.shape, dtype=float)
        self.img = img

    #------------------------------------------------------------------------
    # Functions for updating the image using different maps
    #------------------------------------------------------------------------

    def update(self):
        '''Advance the active mode by one step; does nothing while paused.'''
        if self.paused == True:
            return
        mode = self.mode
        if mode == MAGNIFY_MODE:
            # Each map has its own single-iteration update routine
            step_for_map = {
                DIVERGENCE_MAP: self.update_divergence,
                HANKEL_MAP: self.update_hankel,
                CONVERGENCE_MAP: self.update_convergence,
                PERIOD_MAP: self.update_period,
                ANGLE_MAP: self.update_angle,
                CYCLE_MAP: self.update_cycle,
            }
            step = step_for_map.get(self.map)
            if step is not None:
                step()
        elif mode == ORBIT_MODE:
            self.update_orbit()
        elif mode == CYCLE_MODE:
            self.cycle_palette()

    #------------------------------------------------------------------------
    # Update image functions
    #------------------------------------------------------------------------

    def update_z(self):
        '''Update z to the next point in the orbit, given by z^2 + c, and
        return the indexes of the points which have diverged.'''
        z = self.z
        # Update z to the next point in the orbit, given by z^2 + c
        np.multiply(z, z, out=z)
        np.add(z, self.c, out=z)
        # Points have diverged when the absolute value exceeds 2
        return np.abs(z) > 2.0

    def update_plane_data(self, remaining):
        '''Update plane data to just the points which have not yet diverged'''
        self.z = self.z[remaining]
        self.ix, self.iy = self.ix[remaining], self.iy[remaining]
        self.c = self.c[remaining]
        self.iter += 1

    def update_divergence(self):
        '''Updates the divergence map by one iteration'''
        diverged = self.update_z()
        # The image holds the number of iterations needed to diverge
        self.img[self.ix[diverged], self.iy[diverged]] = self.iter + 1
        self.update_plane_data(np.logical_not(diverged))

    def update_hankel(self):
        '''Updates the hankel map by one iteration.

        A point counts as converged when its new orbit value lies within
        EPSILON of any orbit value saved for it earlier; the iteration number
        at which that happens is written to the image. Diverged points are
        written as zero. Both kinds of point are then dropped.
        '''
        # Index arrays as they were before this iteration removes any points
        ix, iy = self.ix, self.iy
        # Same list object as self.saved, so the edits below apply to both
        saved = self.saved
        diverged = self.update_z()
        # Check which points have converged to the previous point in orbit
        converged = np.abs(saved[-1] - self.z) < EPSILON
        # Compare with all previous points in orbit
        for s in saved[:-1]:
            np.logical_or(converged, np.abs(s - self.z) < EPSILON, out=converged)
        # Any points which have diverged get set to zero in the image
        self.img[ix[diverged], iy[diverged]] = 0
        # Any points which have converged get set to i + 1 in the image
        self.img[ix[converged], iy[converged]] = self.iter + 1
        # Points which have either diverged or converged are removed
        remain = np.logical_not(np.logical_or(diverged, converged))
        self.update_plane_data(remain)
        # Shrink every saved array so it stays aligned with the active points
        for j, s in enumerate(saved):
            saved[j] = s[remain]
        # One more array is kept per iteration, so memory use grows with
        # the iteration count
        self.saved.append(np.copy(self.z))

    def update_cycle(self):
        '''Updates the cycle map by one iteration.

        When the new orbit value lies within EPSILON of an earlier saved
        value, the number of iterations separating the two is written to the
        image as the period. A match with the immediately preceding value is
        written as period 1. Diverged points are written as zero.
        '''
        # Index arrays as they were before this iteration removes any points
        ix, iy = self.ix, self.iy
        # Same list object as self.saved, so the edits below apply to both
        saved = self.saved
        diverged = self.update_z()
        # Check which points have converged to the previous point in orbit
        converged = np.abs(saved[-1] - self.z) < EPSILON
        # Compare with all previous points in orbit
        i = self.iter
        for j, s in enumerate(saved[:-1]):
            matched = np.abs(s - self.z) < EPSILON
            # Points which converge get set to the period of the cycle
            self.img[ix[matched], iy[matched]] = i + 1 - j
        # Any points which have diverged get set to zero in the image
        self.img[ix[diverged], iy[diverged]] = 0
        # Convergence to a fixed point is represented as a cycle of period 1
        self.img[ix[converged], iy[converged]] = 1
        # Points which have either diverged or converged are removed
        # NOTE(review): points matched only against older saved values are
        # not part of `converged`, so they stay active and are tested again
        # on later iterations - confirm this is intended
        remain = np.logical_not(np.logical_or(diverged, converged))
        self.update_plane_data(remain)
        # Shrink every saved array so it stays aligned with the active points
        for j, s in enumerate(saved):
            saved[j] = s[remain]
        # One more array is kept per iteration, so memory use grows with
        # the iteration count
        self.saved.append(np.copy(self.z))

    def update_period(self):
        '''Updates the period map by one iteration'''
        ix, iy = self.ix, self.iy
        closest = self.closest
        diverged = self.update_z()
        # Find the distance of the latest point in orbit from c
        distance = np.abs(self.c - self.z)
        closer = distance < closest
        # If latest point in orbit is closer than previous, update image
        self.img[ix[closer], iy[closer]] = self.iter + 1
        # Save the new minimum distance
        np.minimum(distance, closest, out=closest)
        # Points which just diverged get set to zero in the image
        self.img[ix[diverged], iy[diverged]] = 0
        # remain is an array showing the remaining points
        remain = np.logical_not(diverged)
        self.closest = closest[remain]
        self.update_plane_data(remain)

    def update_convergence(self):
        '''Updates the convergence map by one iteration'''
        ix, iy = self.ix, self.iy
        closest = self.closest
        diverged = self.update_z()
        # Save minimum distance of any point from c
        np.minimum(np.abs(self.z - self.c), closest, out=closest)
        # Points which just diverged get set to zero in the image
        self.img[ix[diverged], iy[diverged]] = 0
        # remain is an array showing the remaining points
        remain = np.logical_not(diverged)
        self.closest = closest[remain]
        self.img[ix[remain], iy[remain]] = np.ceil(self.closest * self.levels)
        self.update_plane_data(remain)

    def update_angle(self):
        '''Updates the angle map by one iteration'''
        ix, iy = self.ix, self.iy
        previous_line = self.previous - self.z
        self.previous[:] = self.z
        diverged = self.update_z()
        # Keep cumulative total of the angles in the orbit. The angle is the
        # arg of the ratio of the previous orbital segment to the current one.
        self.angles += np.abs(np.angle(previous_line/(self.z - self.previous), deg=True))
        # Find the distance of the latest point in orbit from c
        distance = np.abs(self.c - self.z)
        closer = distance < self.closest
        # If latest point in orbit is closer than previous, update image
        # Show sum of interior angles of orbit
        self.img[ix[closer], iy[closer]] = np.maximum(1, np.rint(self.angles[closer]))
        # Save the new minimum distance
        np.minimum(distance, self.closest, out=self.closest)
        # Points which just diverged get set to zero in the image
        self.img[ix[diverged], iy[diverged]] = 0
        # remain is an array showing the remaining points
        remain = np.logical_not(diverged)
        self.update_plane_data(remain)
        self.closest = self.closest[remain]
        self.previous = self.previous[remain]
        self.angles = self.angles[remain]

    #------------------------------------------------------------------------
    # Cycle through the palette
    #------------------------------------------------------------------------

    def cycle_palette(self):
        ''' Perform one iteration of currently active mode'''
        palette = self.palette
        # If cycling forward
        if self.cycle:
            if isinstance(palette, tuple):
                self.palette = (palette[-1],) + palette[:-1]
            else:
                # Modified 6/7/2022 - Palettes are now numpy arrays
                # self.palette = [palette[-1]] + palette[:-1]
                last = palette[-1]
                self.palette[1:] = self.palette[:-1]
                self.palette[0] = last
        # Else must be cycling backward
        else:
            if isinstance(palette, tuple):
                self.palette = palette[1:] + (palette[1],)
            else:
                # Modified 6/7/2022 - Palettes are now numpy arrays
                # self.palette = palette[1:] + [palette[1]]
                first = palette[0]
                self.palette[:-1] = self.palette[1:]
                self.palette[-1] = first

    #------------------------------------------------------------------------
    # Render the image and associated information
    #------------------------------------------------------------------------

    def render(self, surface):
        '''Draw the image, the overlays and the status text onto surface.

        surface -- pygame surface to draw on

        The image is turned into palette indexes (scaled so the maximum maps
        to 255 when normalizing, otherwise taken modulo 256) and colored
        through the active palette. The magnifying rectangle, the grid, the
        iteration number, the image value under the cursor and the mode
        specific lines are drawn on top.
        '''
        # Convert the image to palette indexes in the range 0..255
        if self.normalize:
            peak = self.img.max()
            if peak > 0:
                indexes = (self.img * 255) // peak
            else:
                # Nothing to scale yet; also avoids dividing by zero
                indexes = np.zeros_like(self.img)
        else:
            indexes = self.img % 256
        # Pygame 2 has no screen palette, so colors are looked up by hand.
        # The normalized image goes through the palette as well; it used to
        # be blitted as raw index values.
        pygame.surfarray.blit_array(surface, self.palette[indexes])
        # Draw the magnifying rectangle
        if self.magnify == True:
            r = Rect(self.startx, self.starty,
                     self.endx - self.startx, self.endy - self.starty)
            pygame.draw.rect(surface, self.rect_color, r, 1)
        # Draw the overlying grid
        if self.grid == CARTESIAN_GRID:
            intervals = 12
            for i in range(1, intervals):
                x = (i * HEIGHT)//intervals
                pygame.draw.line(surface, self.grid_color, (x, 0), (x, HEIGHT))
                pygame.draw.line(surface, self.grid_color, (0, x), (HEIGHT, x))
        elif self.grid == POLAR_GRID:
            origin_x, origin_y = self.origin
            # Draw the lines of constant radius
            for i in range(1, 10):
                radius = (i * HEIGHT)//12
                pygame.draw.circle(surface, self.grid_color, self.origin, radius, 1)
            # Draw the lines of constant angle
            lines = 6
            radius = (HEIGHT * 2)//3
            pygame.draw.line(surface, self.grid_color, (0, origin_y), (HEIGHT, origin_y))
            for i in range(1, lines):
                theta = (i*np.pi)/lines
                end_x = origin_x + (radius * np.cos(theta))
                # One line below the horizontal axis and its mirror image above
                for end_y in (origin_y + (radius * np.sin(theta)),
                              origin_y - (radius * np.sin(theta))):
                    pygame.draw.line(surface, self.grid_color, self.origin,
                                     (end_x, end_y))
        # Update the iteration number on the screen
        text_surface = self.button_font.render(str(self.iter), False, self.word_color)
        surface.blit(text_surface, (self.text_x, self.text_y))
        # Update the cursor image number on the screen
        pixel_value = self.img[self.current_pos[0], self.current_pos[1]]
        text_surface = self.button_font.render(str(pixel_value), False, self.word_color)
        surface.blit(text_surface, (self.cursor_x, self.cursor_y))
        # Draw the orbit
        if self.mode == ORBIT_MODE:
            if (len(self.orbit) > 1):
                pygame.draw.lines(surface, self.line_color, False, self.orbit)
        # Draw the line along which the Feigenbaum diagram is created
        elif self.mode == FEIGENBAUM_MODE:
            pygame.draw.line(surface, self.line_color, self.origin, self.current_pos)
        # Cycle mode needs nothing extra: the image is colored through
        # self.palette on every call, so a rotated palette already shows.
        # Blitting the image again here used to erase the grid and the text.

    #------------------------------------------------------------------------
    # Magnify mode
    #------------------------------------------------------------------------

    def start_magnify(self, x, y):
        '''Begin a magnify rectangle with both corners at (x, y).'''
        self.startx = self.endx = x
        self.starty = self.endy = y
        self.magnify = True

    def end_magnify(self, x, y):
        '''Finish the magnifying rectangle and zoom into it.

        x, y -- screen position at which the drag ended

        The selection is the largest square fitting in the rectangle spanned
        by (x, y) and the stored start corner, placed at that rectangle's top
        left. Selections under 2 pixels wide are ignored. Otherwise the viewed
        region is narrowed to the square, the image is restarted and the new
        region is shown in the window caption.
        '''
        if not self.magnify:
            return
        self.magnify = False

        # Dimensions of the area to magnify
        left = min(x, self.startx)
        right = max(x, self.startx)
        top = min(y, self.starty)
        bottom = max(y, self.starty)

        pixel_width = min(right - left, bottom - top)
        if pixel_width < 2:
            return
        # The viewed region is square, so one scale serves both axes
        old_width = self.xmax - self.xmin
        new_width = pixel_width * old_width/HEIGHT

        # Update the new area of the image
        self.xmin = self.xmin + (left * old_width/HEIGHT)
        self.xmax = self.xmin + new_width
        self.ymax = self.ymax - (top * old_width/HEIGHT)
        self.ymin = self.ymax - new_width
        self.initialize_image()
        # initialize_image sets the normal title; replace it with the region.
        # The text is kept on one logical string: a backslash continuation
        # inside the literal used to put the next line's indentation into
        # the caption.
        caption = 'xmin = {:16.14f}, ymin = {:16.14f}, width={:16.14f}'
        pygame.display.set_caption(caption.format(self.xmin, self.ymin, new_width))

    def move_magnify(self, x, y):
        '''Resize the magnifying rectangle while the mouse is dragged.

        The rectangle becomes the largest square fitting in the box spanned
        by (x, y) and the stored start corner, placed at that box's top left.
        Does nothing unless a rectangle is being dragged.
        '''
        if self.magnify == False:
            return
        left, top = min(x, self.startx), min(y, self.starty)
        right, bottom = max(x, self.startx), max(y, self.starty)
        # Restrict the magnifying rectangle to be a square
        side = min(right - left, bottom - top)
        # NOTE(review): the start corner is overwritten with the top left of
        # the box, so after dragging up or left of the starting point that
        # point is no longer available to later calls - confirm intended
        self.startx, self.starty = left, top
        self.endx, self.endy = left + side, top + side

    #------------------------------------------------------------------------
    # Conversions between screen coordinates and points in complex plane
    #------------------------------------------------------------------------

    def xy_to_z(self, x, y):
        '''Return the complex number shown at screen position (x, y).

        The width of the viewed region is used as the scale on both axes.
        '''
        span = self.xmax - self.xmin
        return complex(self.xmin + x * span/HEIGHT,
                       self.ymax - y * span/HEIGHT)

    def z_to_xy(self, z):
        '''Return the screen position (x, y), as integers, of complex z.

        The width of the viewed region is used as the scale on both axes.
        '''
        span = self.xmax - self.xmin
        px = int(((np.real(z) - self.xmin) * HEIGHT)/span)
        py = int(((self.ymax - np.imag(z)) * HEIGHT)/span)
        return (px, py)

    #------------------------------------------------------------------------
    # Orbit mode
    #------------------------------------------------------------------------

    def initialize_orbit(self):
        '''Reset the orbit to the single point z = 0.'''
        origin = 0j
        self.zorbit = [origin]
        # The same point in screen coordinates, for drawing
        self.orbit = [self.z_to_xy(origin)]

    def start_orbit(self, x, y):
        '''Start a new orbit for the parameter c at screen position (x, y).'''
        self.initialize_orbit()
        self.iter = 1
        # c is the point of the complex plane under (x, y)
        self.zc = self.xy_to_z(x, y)

    def update_orbit(self):
        '''Extend the orbit by one point, z^2 + c.

        The orbit is left unchanged once the next point would have an
        absolute value above 2 or the iteration limit has been reached.
        '''
        last = self.zorbit[-1]
        following = last*last + self.zc
        if not (np.abs(following) <= 2 and (self.iter < self.itermax)):
            return
        self.iter += 1
        self.zorbit.append(following)
        self.orbit.append(self.z_to_xy(following))

    def plot_orbit(self, x, y):
        '''Plot how far each point of the current orbit lies from c.

        x, y -- not used; the stored orbit and its parameter c are plotted
        '''
        c = self.zc
        distances = np.abs(np.array(self.zorbit) - c)
        plt.plot(distances)
        plt.xlabel('Iteration')
        plt.ylabel('Distance from c')
        heading = 'Distance of orbit from c = {:8.6f} + {:8.6f}i'.format(c.real, c.imag)
        plt.title(heading)
        plt.show()

    #------------------------------------------------------------------------
    # Distribution of iterations needed to diverge
    #------------------------------------------------------------------------

    def show_distribution(self):
        '''Plot a normalized histogram of the image values.

        One bin is used per iteration performed so far, plus one, and the
        mean image value is shown in the title.
        '''
        # Flatten whatever size the image has instead of assuming HEIGHT*HEIGHT
        flat = self.img.ravel()
        # density=True replaces normed, which current Matplotlib rejects
        plt.hist(flat, bins=self.iter + 1, fc='b', ec='b', density=True)
        av = np.mean(flat)
        plt.xlabel('Number of iterations')
        plt.ylabel('Relative frequency')
        plt.title('Iterations needed to diverge, mean = {:d}'.format(int(av)))
        plt.show()

    #------------------------------------------------------------------------
    # Feigenbaum diagram for points on a given constant angle from the origin
    #------------------------------------------------------------------------

    def show_feigenbaum(self, x, y):
        '''Plot orbit magnitudes for parameters on a line from the origin.

        x, y -- screen position of the far end of the line

        1024 values of c are spread evenly from 0 to the complex number at
        (x, y). Every orbit is first iterated 1000 times without plotting,
        then |z| is plotted against |c| for a further 200 iterations.
        '''
        max_iters = 1000
        z_pt = self.xy_to_z(x, y)
        c = np.linspace(0, z_pt, 1024)
        z = np.copy(c)

        def advance(z, c):
            # One iteration of z^2 + c, keeping only points with |z| < 2
            np.multiply(z, z, out=z)
            np.add(z, c, out=z)
            bounded = np.abs(z) < 2.0
            return z[bounded], c[bounded]

        # Initial run to get convergence to a repeating orbit
        for _ in range(max_iters):
            z, c = advance(z, c)
        # Plot the orbits
        for _ in range(max_iters//5):
            plt.plot(np.abs(c), np.abs(z), 'k.', markersize=0.05)
            z, c = advance(z, c)
        plt.xlabel('|c|')
        plt.ylabel('Distance of orbit from origin')
        plt.title('z = {:8.6f} + {:8.6f}i'.format(z_pt.real, z_pt.imag))
        plt.show()

    #------------------------------------------------------------------------
    # Event handlers
    #------------------------------------------------------------------------

    def mousedown(self, x, y, screen):
        '''Handle a mouse button press at (x, y) according to the mode.

        Magnify mode starts a rectangle, Feigenbaum mode shows the diagram
        and orbit mode plots the orbit. screen is not used.
        '''
        self.current_pos = (x, y)
        mode = self.mode
        if mode == MAGNIFY_MODE:
            self.start_magnify(x, y)
        elif mode == FEIGENBAUM_MODE:
            self.show_feigenbaum(x, y)
        elif mode == ORBIT_MODE:
            self.plot_orbit(x, y)

    def mouseup(self, x, y, screen):
        '''Handle a mouse button release at (x, y).

        Only magnify mode reacts, by finishing the magnifying rectangle.
        screen is not used.
        '''
        if self.mode != MAGNIFY_MODE:
            return
        self.end_magnify(x, y)

    def mousemove(self, x, y, screen):
        '''Handle mouse movement to (x, y) and remember the position.

        Magnify mode resizes the rectangle; orbit mode restarts the orbit at
        the new position. screen is not used.
        '''
        mode = self.mode
        if mode == MAGNIFY_MODE:
            self.move_magnify(x, y)
        elif mode == ORBIT_MODE:
            self.start_orbit(x, y)
        # Recorded last, after the mode handler has run
        self.current_pos = (x, y)

    def keydown(self, key, screen):
        '''Handle a key press. screen is not used.

        a, d, h, i, p, t -- switch to the angle, divergence, hankel,
                            convergence, period or cycle map and restart
        c        -- palette cycling: forward, then backward, then off
        f        -- Feigenbaum mode
        g        -- step through the grid types
        m        -- magnify mode, restarting the current image
        n        -- toggle image normalization
        o        -- orbit mode
        r        -- restart in magnify mode on the initial region
        s        -- show the distribution of image values
        space    -- pause or resume
        up/down  -- double or halve the number of levels
        0-9      -- choose a palette
        '''
        # Keys which select a map; all of them restart in magnify mode
        map_for_key = {
            K_a: ANGLE_MAP,
            K_d: DIVERGENCE_MAP,
            K_h: HANKEL_MAP,
            K_i: CONVERGENCE_MAP,   # Internal structure
            K_p: PERIOD_MAP,        # Which period is the orbit closest to
            K_t: CYCLE_MAP,         # Period of the orbit
        }
        if key in map_for_key:
            self.map = map_for_key[key]
            self.mode = MAGNIFY_MODE
            self.initialize_orbit()
            self.initialize_image()
        elif key == K_c:
            # Cycle the palette: forward first, then backward, then stop
            if self.mode != CYCLE_MODE:
                self.mode = CYCLE_MODE
                self.cycle = True
                self.update_title()
            elif self.cycle:
                self.cycle = False
            else:
                self.mode = MAGNIFY_MODE
                self.update_title()
        elif key == K_f:
            self.mode = FEIGENBAUM_MODE
            self.update_title()
        elif key == K_g:
            # Step to the next grid type
            self.grid = (self.grid + 1) % 3
        elif key in (K_m, K_r):
            # Magnify mode; r also goes back to the initial region
            self.mode = MAGNIFY_MODE
            if key == K_r:
                self.initialize_area()
            self.initialize_orbit()
            self.initialize_image()
        elif key == K_n:
            # Toggle the image normalization
            self.normalize = not self.normalize
        elif key == K_o:
            # Orbit mode
            self.mode = ORBIT_MODE
            self.initialize_orbit()
            self.update_title()
        elif key == K_s:
            # Show the distribution of iterations needed to diverge
            self.show_distribution()
        elif key == K_SPACE:
            # Turn pause on or off
            self.paused = not self.paused
        elif key == K_UP:
            # Halve the spacing shown between levels
            self.levels *= 2
        elif key == K_DOWN:
            # Double the spacing shown between levels
            if self.levels > 1:
                self.levels /= 2
        elif key in NUMBER_KEYS:
            # Change the palette, unless the number is past the end of the list
            slot = NUMBER_KEYS.index(key)
            if slot < len(self.palettes):
                self.palette = self.palettes[slot]

#--------------------------------------------------------------------------
#
# Pygame Initialization
#
#--------------------------------------------------------------------------

# pygame is initialised first because Mandelbrot() creates fonts
pygame.init()
# Square window of SCREEN_SIZE pixels, requesting a bit depth of 8
screen = pygame.display.set_mode(SCREEN_SIZE, 0, 8)
main = Mandelbrot()
# Draw the initial image before entering the event loop
main.render(screen)

# main.add_palette(screen.get_palette())
pygame.display.update()

#--------------------------------------------------------------------------
#
# Main event handling loop
#
#--------------------------------------------------------------------------

on = True

# Each pass handles the waiting events, advances the active mode by one
# step and redraws the window
while on:

    for event in pygame.event.get():
        if event.type == QUIT:
            on = False
            pygame.quit()
            sys.exit()
        elif event.type == KEYDOWN:
            main.keydown(event.key, screen)
        elif event.type == MOUSEBUTTONDOWN:
            x, y = pygame.mouse.get_pos()
            main.mousedown(x, y, screen)
        elif event.type == MOUSEBUTTONUP:
            x, y = pygame.mouse.get_pos()
            main.mouseup(x, y, screen)
        elif event.type == MOUSEMOTION:
            x, y = pygame.mouse.get_pos()
            main.mousemove(x, y, screen)

    main.update()
    main.render(screen)

    pygame.display.flip()
