#
# Genetic Algorithm Python module
#

# Copyright [2006] [alex at delandgraaf dot com]
# Licensed under the LGPL

# Usage:
# import ga
#
# ga = ga.GA()
# ga.evolve()

# Function arguments (default values are used for any that are omitted):
#
# ga.GA(population_size, gene_size, crossover_rate, mutation_rate, list_of_alleles)
# ga.evolve(number_of_generations_to_process)
#
# ga.set_fitness(your_fitness_function)

# Interesting public variables (besides the ones you can modify using arguments)
#
# ga.debug = 1   (turns on lots of output)
# ga.halt  = X.X (stop if this fitness is reached; a negative value, the default, disables halting)

# Note: crossover_rate is the per-generation chance that two randomly chosen
#       entities will reproduce
#       mutation_rate is the per-generation chance that each entity will have
#       one allele changed

import math, random

class GA:
    """ GA class, containing all that is required for a
    simple genetic algorithm.

    Each entity is a list of ``gene_size`` alleles drawn from ``alleles``.
    Every generation applies mutation, then crossover; newly created
    entities replace the least-fit members of the population.
    """
    __entities = []          # unused; retained for backward compatibility
    crossover_rate = 0.0
    mutation_rate = 0.0
    fitness_func = lambda x: 0
    alleles = []
    debug = 0                # 1 turns on verbose output via pdebug()
    halt_reached = 0         # set to 1 by evolve() when halting early
    halt = -1                # fitness threshold; negative disables halting

    def __init__(self, pop = 5, gene_size = 10, crossover_rate = 0.4, mutation_rate = 0.05, alleles = None):
        """ init the GA using some default values and a default fitness
        function.

        pop            -- number of entities in the population
        gene_size      -- number of alleles per entity
        crossover_rate -- per-generation chance of one crossover
        mutation_rate  -- per-generation, per-entity chance of a mutation
        alleles        -- possible allele values (default [0, 1])
        """
        if alleles is None:
            alleles = [0, 1]
        self.crossover_rate = crossover_rate
        self.mutation_rate = mutation_rate
        # Private copy: never share (or later mutate) the caller's list.
        self.alleles = list(alleles)
        self.fitness_func = self.default_fitness
        self.populate(pop, gene_size, self.alleles)

    def default_fitness(self, entity):
        """ Default fitness: the number of alleles equal to 1, as a float. """
        return float(sum(1 for allele in entity if allele == 1))

    def set_fitness(self, x):
        """ Use callable x(entity) -> number as the fitness function. """
        self.fitness_func = x

    def populate(self, pop, gene_size, pos_list):
        """ Populate self.entities with `pop` random entities, each made of
        `gene_size` alleles chosen uniformly from pos_list. """
        # Pick allele values directly; pos_list need not contain indices.
        self.entities = [[random.choice(pos_list) for _ in range(gene_size)]
                         for _ in range(pop)]

    def echo_entities(self):
        """ Print every entity, one per line. """
        for entity in self.entities:
            print(str(entity))

    def calc_fitness(self, entity):
        """ Return the fitness of `entity` using the current fitness function. """
        return self.fitness_func(entity)

    def echo_fitness(self):
        """ Print every entity together with its fitness. """
        for entity in self.entities:
            print(str(entity) + " -> " + str(self.calc_fitness(entity)))

    def avg_fitness(self):
        """ Return the mean fitness of the population (0.0 when empty). """
        if not self.entities:
            return 0.0
        total = 0.0  # float accumulator avoids Python 2 integer division
        for entity in self.entities:
            total += self.calc_fitness(entity)
        return total / len(self.entities)

    def pdebug(self, string):
        """ Print `string` only when debug output is enabled. """
        if self.debug == 1:
            print(string)

    def get_max_fitness(self):
        """ Return the entity with the highest fitness value (the first one
        on ties). Raises IndexError for an empty population. """
        if not self.entities:
            raise IndexError("population is empty")
        return max(self.entities, key=self.calc_fitness)

    def evolve(self, generations = 100):
        """ Process a number of generations
        Stop if:
        - number of generations = generations
        - halt variable has been set (>= 0) AND
          there is an entity with a fitness which has obtained or exceeded
          the halt value
        Return the list of entities, or [halt-entity] if prematurely halted
          """

        self.halt_reached = 0

        for gen in range(generations):
            self.do_mutation()
            self.do_crossover()
            print("Average fitness generation " + str(gen) + ": " + str(self.avg_fitness()))
            if self.debug == 1:
                self.echo_fitness()
            if self.halt >= 0:
                max_entity = self.get_max_fitness()
                if self.calc_fitness(max_entity) >= self.halt:
                    self.halt_reached = 1
                    return [max_entity]

        return self.entities


    ### Operators ###

    def do_mutation(self):
        """ Give each entity a mutation_rate chance to produce a mutant:
        a copy with one allele changed to a different allele value. The
        mutant replaces the worst individual. """
        # Iterate a snapshot: replace_worst() assigns into self.entities,
        # and a freshly inserted mutant should not be mutated again here.
        for entity in self.entities[:]:
            if random.random() > self.mutation_rate:
                continue
            if not entity:
                continue  # nothing to mutate in an empty gene
            new_entity = entity[:]
            allele_nr = random.randint(0, len(new_entity) - 1)
            self.pdebug("Mutating:")
            choices = self.alleles[:]
            if new_entity[allele_nr] in choices:
                choices.remove(new_entity[allele_nr])
            if not choices:
                continue  # no alternative allele value exists
            new_entity[allele_nr] = random.choice(choices)
            self.replace_worst(new_entity)

    def do_crossover(self):
        """ With probability crossover_rate, randomly select two individuals
        and create two new individuals by single-point crossover. The
        parents are left intact; the children replace the worst two
        individuals. """

        if random.random() > self.crossover_rate:
            return
        if len(self.entities) < 2:
            return  # two distinct parents are required

        candidates = self.entities[:]
        parent_1 = random.choice(candidates)
        candidates.remove(parent_1)
        parent_2 = random.choice(candidates)

        self.pdebug("Crossing over: " + str(parent_1) + " and " + str(parent_2))

        gene_len = min(len(parent_1), len(parent_2))
        if gene_len == 0:
            return

        # Work on copies so the parents stay in the population unchanged.
        child_1 = parent_1[:]
        child_2 = parent_2[:]
        begin_switching = random.randint(0, 1)
        switch_allele = random.randint(0, gene_len - 1)
        self.pdebug("Begin switching: " + str(begin_switching))
        self.pdebug("Switch at allele: " + str(switch_allele))
        for i in range(gene_len):
            # Swap either the head (before switch_allele) or the tail.
            if (begin_switching == 1 and i < switch_allele) or (begin_switching == 0 and i >= switch_allele):
                child_1[i], child_2[i] = child_2[i], child_1[i]

        # children have been created, replace the two worst entities
        self.replace_worst_two(child_1, child_2)

    ### Selection ###

    def _worst_index(self, exclude = None):
        """ Return the index of the least-fit entity, skipping index
        `exclude`, or None if there is no candidate. Ties go to the
        highest index. """
        worst = None
        worst_fit = None
        for i, entity in enumerate(self.entities):
            if i == exclude:
                continue
            fit = self.calc_fitness(entity)
            if worst is None or fit <= worst_fit:
                worst, worst_fit = i, fit
        return worst

    def replace_worst_two(self, new_entity_1, new_entity_2):
        """ Replace the worst entity with new_entity_1 and the second-worst
        (a different slot) with new_entity_2. In a one-entity population
        both land in the single slot. Raises IndexError if empty. """
        worst_1 = self._worst_index()
        if worst_1 is None:
            raise IndexError("population is empty")
        worst_2 = self._worst_index(exclude = worst_1)
        if worst_2 is None:
            worst_2 = worst_1  # only one slot exists

        self.pdebug("Replacing " + str(self.entities[worst_1]) + " with fitness " + str(self.calc_fitness(self.entities[worst_1])))
        self.pdebug("Using new entity " + str(new_entity_1) + " using fitness: " + str(self.calc_fitness(new_entity_1)))
        self.entities[worst_1] = new_entity_1[:]

        self.pdebug("Replacing " + str(self.entities[worst_2]) + " with fitness " + str(self.calc_fitness(self.entities[worst_2])))
        self.pdebug("Using new entity " + str(new_entity_2) + " using fitness: " + str(self.calc_fitness(new_entity_2)))
        self.entities[worst_2] = new_entity_2[:]

    def replace_worst(self, new_entity):
        """ Find the entity with the worst fitness,
        and replace it with a copy of new_entity. Raises IndexError if the
        population is empty. """
        worst = self._worst_index()
        if worst is None:
            raise IndexError("population is empty")

        self.pdebug("Replacing " + str(self.entities[worst]) + " with fitness " + str(self.calc_fitness(self.entities[worst])))
        self.pdebug("New entity fitness: " + str(self.calc_fitness(new_entity)))
        self.entities[worst] = new_entity[:]

