import java.io.FileReader;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.FileNotFoundException;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.TreeSet;
import java.util.Scanner;
import java.util.Collections;
import java.text.DecimalFormat;

/**
 * A small word-frequency spam filter.
 *
 * <p>Word counts are collected from numbered training files
 * ({@code spam1.txt}..{@code spam5.txt} and {@code ham1.txt}..{@code ham5.txt}),
 * every word of {@code test.txt} is given a "spamicity" derived from those
 * counts, and the lowest and highest spamicity are combined into an overall
 * spam probability for the test file.
 */
public class SpamFilter
{
    /** Formats spamicity values with three decimal places for display. */
    private static final DecimalFormat df = new DecimalFormat ("0.000");

    /** Regex matching the runs of characters that separate words in a line. */
    private static final String DELIMITER = "[? ,!.;]+";

    /** Number of numbered training files read per category. */
    private static final int TRAINING_FILE_COUNT = 5;

    /** Spamicity assigned to a word that occurs in neither training map. */
    private static final double UNKNOWN_WORD_SPAMICITY = 0.4;

    /**
     * Adds the word counts of the training files of one category to {@code map}.
     *
     * <p>Files {@code <prefix>1.txt} through {@code <prefix>5.txt} are read from
     * the working directory, where the prefix is {@code "spam"} when
     * {@code type} equals {@code "spam"} and {@code "ham"} for every other
     * value. A file that is missing or unreadable is reported on standard
     * output and skipped; the remaining files are still processed.
     *
     * @param map  word-to-count map that is updated in place
     * @param type training category; {@code "spam"} selects the spam files,
     *             anything else selects the ham files
     */
    public static void train (HashMap <String, Integer> map, String type)
    {
        // Decide the prefix once so that all five file names agree with it.
        String prefix = type.equals ("spam") ? "spam" : "ham";

        for (int i = 1; i <= TRAINING_FILE_COUNT; i++)
        {
            String fileName = prefix + i + ".txt";
            // try-with-resources closes the reader even when reading fails.
            try (BufferedReader br = new BufferedReader (new FileReader (fileName)))
            {
                String line;
                while ((line = br.readLine()) != null)
                {
                    countWords (line, map);
                }
            }
            catch (FileNotFoundException e)
            {
                System.out.println ("No such File Name Exists");
            }
            catch (IOException e)
            {
                System.out.println ("Cannot read file");
            }
        }
    }

    /** Splits {@code line} into words and increments each word's count in {@code map}. */
    private static void countWords (String line, HashMap <String, Integer> map)
    {
        for (String word : line.split (DELIMITER))
        {
            // split() yields a leading "" for blank lines and for lines that
            // start with a delimiter; that is not a word.
            if (word.isEmpty())
            {
                continue;
            }
            Integer count = map.get (word);
            map.put (word, count == null ? 1 : count + 1);
        }
    }

    /**
     * Prints every word of {@code map} with its frequency, in sorted word order.
     *
     * @param map word-to-count map to print
     */
    public static void displayMap(HashMap <String,Integer> map)
    {
        TreeSet<String> set = new TreeSet<String>(map.keySet());
        System.out.println ("Word" + "\t\t\t" + "Frequency");
        for (String s : set)
        {
            System.out.println (s + "\t\t\t" + map.get(s));
        }
        System.out.println();
    }

    /**
     * Prints every word of {@code map} with its spamicity (three decimal
     * places), in sorted word order.
     *
     * @param map word-to-spamicity map to print
     */
    public static void displayMapInput(HashMap <String,Double> map)
    {
        TreeSet<String> set = new TreeSet<String>(map.keySet());
        System.out.println ("Word" + "\t\t\t" + "Spamicity");
        for (String s : set)
        {
            System.out.println (s + "\t\t\t" + df.format(map.get(s)));
        }
        System.out.println();
    }

    /**
     * Combines the smallest and largest spamicity of {@code list} into one
     * probability: {@code ab / (ab + (1-a)(1-b))}, where {@code a} is the
     * minimum and {@code b} the maximum. The sorted values and the chosen
     * {@code a} and {@code b} are printed to standard output.
     *
     * <p>The caller's list is not modified. When both products are zero
     * (minimum 0 and maximum 1) the formula would be 0/0; {@code 0.5} is
     * returned instead of {@code NaN} because the two extremes cancel out.
     *
     * @param list spamicity values; must contain at least one element
     * @return the combined spam probability
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    public static double spamChance (ArrayList<Double> list)
    {
        if (list == null || list.isEmpty())
        {
            throw new IllegalArgumentException (
                "spamChance requires at least one spamicity value");
        }

        // Sort a copy so the caller's list keeps its original order.
        ArrayList<Double> sorted = new ArrayList<Double>(list);
        Collections.sort(sorted);
        System.out.println ("Sorted values of Spamicity in the list:");
        for (int i = 0; i < sorted.size(); i++)
        {
            System.out.print (df.format(sorted.get(i)) + "  " );
        }
        System.out.println();

        double a = sorted.get(0);
        double b = sorted.get(sorted.size() - 1);
        System.out.println ("a = " + a +  " b = " + b);

        double spamEvidence = a * b;
        double hamEvidence = (1 - a) * (1 - b);
        double total = spamEvidence + hamEvidence;
        if (total == 0)
        {
            // a == 0 and b == 1: avoid 0/0 and report an undecided result.
            return 0.5;
        }
        return spamEvidence / total;
    }

    /** Returns the count stored for {@code word} in {@code map}, or 0 when the word is absent. */
    private static int frequency (HashMap <String, Integer> map, String word)
    {
        Integer count = map.get (word);
        return count == null ? 0 : count;
    }

    /** Returns the share of a word's per-file rate that comes from spam, or the unknown-word default when it was never seen. */
    private static double wordSpamicity (int spamFreq, int hamFreq)
    {
        if (spamFreq == 0 && hamFreq == 0)
        {
            return UNKNOWN_WORD_SPAMICITY;
        }
        double spamRate = spamFreq / (double) TRAINING_FILE_COUNT;
        double hamRate = hamFreq / (double) TRAINING_FILE_COUNT;
        return spamRate / (spamRate + hamRate);
    }

    /**
     * Trains the spam and ham word maps, prints them, scores every word of
     * {@code test.txt} and prints the resulting spam probability.
     *
     * @param args command-line arguments (not used)
     */
    public static void main (String args [])
    {
        HashMap <String, Integer> spam  = new HashMap <String, Integer>();
        HashMap <String, Integer> ham = new HashMap <String, Integer>();
        HashMap <String, Double> testWords = new HashMap <String, Double>();
        ArrayList <Double> list = new ArrayList <Double>();

        train(spam, "spam");
        train(ham, "ham");
        System.out.println ("Trained Spam Database:");
        displayMap (spam);
        System.out.println();
        System.out.println ("Trained Ham Database:");
        displayMap (ham);
        System.out.println();

        try (BufferedReader br = new BufferedReader (new FileReader ("test.txt")))
        {
            System.out.println ("Spamicity for each word in test.txt");
            String line;
            while ((line = br.readLine()) != null)
            {
                for (String word : line.split (DELIMITER))
                {
                    // Skip the empty token split() produces for blank lines or
                    // leading delimiters; it would otherwise be scored as an
                    // unknown word and distort the minimum spamicity.
                    if (word.isEmpty())
                    {
                        continue;
                    }
                    int spamFreq = frequency (spam, word);
                    int hamFreq = frequency (ham, word);
                    double spamicity = wordSpamicity (spamFreq, hamFreq);

                    System.out.println ("Word = " + word + "  Spam Frequency = " + spamFreq);
                    System.out.println ("Word = " + word + "  Ham Frequency = " + hamFreq);
                    System.out.println ("Word = " + word + "  Spamicity = " + df.format(spamicity));
                    testWords.put (word, spamicity);
                    list.add(spamicity);
                    System.out.println();
                }
            }
        }
        catch (FileNotFoundException e)
        {
            System.out.println ("No such File Name Exists");
        }
        catch (IOException e)
        {
            System.out.println ("Cannot read file");
        }

        System.out.println ("Spamicity for words in test.txt:");
        displayMapInput (testWords);

        if (list.isEmpty())
        {
            // test.txt was missing, unreadable or contained no words.
            System.out.println ("No words were read from test.txt; cannot compute spam probability");
        }
        else
        {
            System.out.println ("Probability of test.txt being spam  = " + spamChance(list));
        }
    }
}