import re

# Matches one "<count> <colour> bag(s)" item (terminated by "," or ".") of a rule
# line such as "light red bags contain 1 bright white bag, 2 muted yellow bags.".
# On the first item of a line the optional leading group also captures the outer
# bag colour, so ``findall`` yields tuples ``(outer_colour_or_empty, count,
# inner_colour)``; "... contain no other bags." lines yield no match at all.
# The count accepts several digits so that e.g. "12 bright white bags" is not
# misread as "2 bright white bags" with the outer colour lost.
bag_re = re.compile(r'(?:([a-z\s]+) bags contain)?(?:\s?([0-9]+) ([a-z\s]+)\sbags?[,.])')

# Default rules file read by the functions in this module.
input_file = "input.txt"


def find_bags_containing_bag(search_bag="shiny gold", rules_path=None):
    """Count the bag colours that can eventually contain a ``search_bag`` bag.

    Every rule line is matched with ``bag_re``; for each inner colour listed on
    a line, the outer colour of that line is recorded as a possible container.
    The colours reachable from ``search_bag`` through these "is contained in"
    links are then collected with a graph walk. Parsing stops at the first
    empty line of the file.

    Args:
        search_bag: colour of the bag to look for, e.g. ``"shiny gold"``.
        rules_path: rules file to read; ``None`` (the default) uses the
            module-level ``input_file``.

    Returns:
        The number of distinct colours, excluding ``search_bag`` itself, that
        can hold ``search_bag`` directly or indirectly. The result is also
        printed.
    """
    if rules_path is None:
        rules_path = input_file

    # Maps a bag colour to the set of colours whose rule lists it as contents.
    bag_relations = {}

    with open(rules_path, encoding="utf-8") as rules:
        for line in rules:
            # Keep the original behaviour of stopping at the first empty line.
            if line == "\n":
                break

            match = bag_re.findall(line)
            if not match:
                # e.g. "... bags contain no other bags." lists no contents.
                continue

            # Only the first match of a line carries the outer colour.
            parent_bag = match[0][0]

            for _, _, child_bag in match:
                # setdefault + set never discards previously recorded parents
                # and silently ignores duplicate (child, parent) pairs.
                bag_relations.setdefault(child_bag, set()).add(parent_bag)

    # Walk "upwards" from search_bag: every colour reachable in the reverse
    # graph can eventually contain it. Visit order does not affect the count,
    # so a list used as a stack (O(1) pop) and a set for membership suffice.
    seen = {search_bag}
    to_visit = [search_bag]
    while to_visit:
        current = to_visit.pop()
        for container in bag_relations.get(current, ()):
            if container not in seen:
                seen.add(container)
                to_visit.append(container)

    # search_bag itself is in `seen` but is not one of the containers.
    usable_bags = len(seen) - 1

    print(f"The number of bags that can contain a {search_bag} bag is {usable_bags}")
    return usable_bags


def count_bags_contained_in_bag(search_bag="shiny gold", rules_path=None):
    """Count how many bags a single ``search_bag`` bag holds in total.

    Every rule line is matched with ``bag_re`` to record, for each outer colour,
    the list of ``(inner colour, quantity)`` pairs it directly contains (a
    colour appearing on several lines keeps its last rule). Parsing stops at
    the first empty line of the file.

    The total for a colour is ``sum(quantity * (1 + total(inner)))`` over its
    direct contents: each of the ``quantity`` inner bags counts itself plus
    everything nested inside it. Colours without a rule contain nothing.

    Args:
        search_bag: colour of the outermost bag, e.g. ``"shiny gold"``.
        rules_path: rules file to read; ``None`` (the default) uses the
            module-level ``input_file``.

    Returns:
        The total number of bags nested (at any depth) inside one
        ``search_bag``. The result is also printed.

    Raises:
        ValueError: if the rules reachable from ``search_bag`` contain a
            containment cycle (the total would be infinite).
    """
    if rules_path is None:
        rules_path = input_file

    # Maps a bag colour to the (inner colour, quantity) pairs it directly holds.
    bags_contained = {}

    with open(rules_path, encoding="utf-8") as rules:
        for line in rules:
            # Keep the original behaviour of stopping at the first empty line.
            if line == "\n":
                break

            match = bag_re.findall(line)
            if not match:
                # e.g. "... bags contain no other bags." lists no contents.
                continue

            # Only the first match of a line carries the outer colour.
            parent_bag = match[0][0]
            bags_contained[parent_bag] = [
                (child_bag, int(quantity)) for _, quantity, child_bag in match
            ]

    # totals[colour] = number of bags nested inside ONE bag of that colour.
    # Memoising per colour evaluates each colour once, instead of once per path
    # as expanding (bag, multiplier) pairs breadth-first would.
    totals = {}
    # Colours whose contents are being expanded; meeting one again before it
    # is finished means it (indirectly) contains itself.
    on_path = set()
    # Iterative post-order DFS over (colour, children_done) entries; an explicit
    # stack avoids Python's recursion limit on deeply nested rule chains.
    stack = [(search_bag, False)]
    while stack:
        bag, children_done = stack.pop()

        if children_done:
            # All direct contents have been totalled by now.
            totals[bag] = sum(
                quantity * (1 + totals[child_bag])
                for child_bag, quantity in bags_contained.get(bag, ())
            )
            on_path.discard(bag)
            continue

        if bag in totals:
            continue
        if bag in on_path:
            raise ValueError(f"bag rules contain a cycle through {bag!r}")

        on_path.add(bag)
        # Re-visit this colour after all of its contents are resolved.
        stack.append((bag, True))
        for child_bag, _ in bags_contained.get(bag, ()):
            if child_bag not in totals:
                stack.append((child_bag, False))

    total_bags = totals[search_bag]

    print(f"The total number of bags a {search_bag} contains {total_bags}")
    return total_bags