This is a DP problem on trees and I tried to solve it using the following approach.

Store a DP state DP[current][mode] = min_cost for sub-tree. mode is equal to 0 if the parent of the current node has not bought a family ticket and 1 if it has.

The recurrence is:

if(mode == 0)
DP[current][mode] = min(cost_of_single_ticket + sum of DP[child][0] for all children, cost_of_family_ticket + sum of DP[child][1] for all children)


DP[current][mode] = min(sum of DP[child][0] for all children, cost_of_family_ticket + sum of DP[child][1] for all children)

The algorithm seems right but I am getting a wrong answer. Here is my code but it is very long and confusing. If you have already solved the problem could you provide some tricky test cases?

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.StringTokenizer;

public class Main 
    private static MyScanner sc;
    private static PrintWriter out;
    private static int single_cost;
    private static int family_cost;
    private static HashMap<String, Integer> map;
    private static ArrayList<String>[] Tree;
    private static HashMap<String, Integer> times;
    private static State[][] DP;
    public static void main(String[] args)
        sc = new MyScanner();
        out = new PrintWriter(System.out);
        single_cost = sc.nextInt();
        family_cost = sc.nextInt();
        int counter = 1;
        times = new HashMap();
        map = new HashMap();
        HashMap<String, ArrayList<String>> data = new HashMap();
        int unique_id = 0;
        data = new HashMap();
        boolean flag = false;
            if(single_cost == 0 && family_cost == 0) break;
            String k = sc.nextLine();
            if(k.length() != 0)
            String[] line = k.split("\\s+");
                int nu_s = Integer.parseInt(line[0]);
                int nu_f = Integer.parseInt(line[1]);
                if(nu_s == 0 && nu_f == 0)
                    DP = new State[map.size()][2];
                    Tree = new ArrayList[map.size()];
                    for(int i = 0; i < map.size(); i++) Tree[i] = new ArrayList();
                    for(String parent : data.keySet())
                        Tree[map.get(parent)] = new ArrayList();
                        for(String child : data.get(parent)) Tree[map.get(parent)].add(child);
                      for(String parent : data.keySet())
                        //System.out.println("Adding " + parent);
                        times.put(parent, 1);
                    for(String parent : data.keySet())
                        for(String child : data.get(parent))
                            if(times.containsKey(child)) times.put(child, times.get(child) + 1);
                    ArrayList<String> roots = get_root();
                  //  System.out.println(root + " is the root.");
                    int s = 0;
                    int f = 0;
                    int c = 0;
                    for(String root : roots)
                        State t = solver(root, 0);
                      s += t.single_qty;
                     f += t.family_qty;
                      c += t.cost();
                      out.println((counter++) + ". " + s + " " + f + " " + c);
                    flag = true;
                     DP = new State[map.size()][2];
                    Tree = new ArrayList[map.size()];
                      for(int i = 0; i < map.size(); i++) Tree[i] = new ArrayList();
                    for(String parent : data.keySet())
                        Tree[map.get(parent)] = new ArrayList();
                        for(String child : data.get(parent)) Tree[map.get(parent)].add(child);
                      for(String parent : data.keySet())
                        //System.out.println("Adding " + parent);
                        times.put(parent, 1);
                    for(String parent : data.keySet())
                        for(String child : data.get(parent))
                            if(times.containsKey(child)) times.put(child, times.get(child) + 1);
                    ArrayList<String> roots = get_root();
                  //  System.out.println(root + " is the root.");
                    int s = 0;
                    int f = 0;
                    int c = 0;
                    for(String root : roots)
                        State t = solver(root, 0);
                      s += t.single_qty;
                     f += t.family_qty;
                      c += t.cost();
                      out.println((counter++) + ". " + s + " " + f + " " + c);
                      // new
                    data = new HashMap();
                    times = new HashMap();
                    map = new HashMap();
                    unique_id = 0;
                    single_cost = nu_s;
                    family_cost = nu_f;
                if(!map.containsKey(line[0])) map.put(line[0], unique_id++);
                if(!data.containsKey(line[0])) data.put(line[0], new ArrayList());
                for(int i = 1; i < line.length; i++)
                  //  System.out.println("Parent -->" + line[0]);
                    if(!map.containsKey(line[i])) map.put(line[i], unique_id++);
                  //  System.out.println("Child " + i + " -->" + line[i]);

    private static ArrayList<String> get_root()
        ArrayList<String> L = new ArrayList();
        for(String node : times.keySet())
            if(times.get(node) == 1) L.add(node);
        return L;    
    private static boolean isNumber(String a)
            int k = Integer.parseInt(a);
            return true;
        catch(Exception e) {return false;}
    private static State solver(String current, int mode)
     //   System.out.println(current + " --> " + mode);
            if(mode == 0)
                State t = new State();
                t.single_qty += 1;
               // System.out.println("At " + current + " and sending " + t.cost()+" one single qty up mode 0");
                return t;
              //  System.out.println("At " + current + " and sending 0 one single qty up mode 1");
                return new State();
            if(DP[map.get(current)][mode] != null) return DP[map.get(current)][mode];
                if(mode == 0)
                    State curr_state_one = new State();
                    curr_state_one.single_qty += 1;
                    State t;
                    for(String child : Tree[map.get(current)])
                        t = solver(child, 0);
                        curr_state_one.single_qty += t.single_qty;
                        curr_state_one.family_qty += t.family_qty;
                    State curr_state_two = new State();
                    curr_state_two.family_qty += 1;
                    for(String child : Tree[map.get(current)])
                        t = solver(child, 1);
                        curr_state_two.single_qty += t.single_qty;
                        curr_state_two.family_qty += t.family_qty;
                    DP[map.get(current)][mode] = curr_state_one.minimum(curr_state_two);
                 //   System.out.println("At " + current + " and sending "+ DP[map.get(current)][mode].cost() +" one single qty up mode 0");
                    return DP[map.get(current)][mode];
                     State curr_state_one = new State();
                     State t;
                    for(String child : Tree[map.get(current)])
                        t = solver(child, 0);
                        curr_state_one.single_qty += t.single_qty;
                        curr_state_one.family_qty += t.family_qty;
                    State curr_state_two = new State();
                    curr_state_two.family_qty += 1;
                    for(String child : Tree[map.get(current)])
                        t = solver(child, 1);
                        curr_state_two.single_qty += t.single_qty;
                        curr_state_two.family_qty += t.family_qty;
                    DP[map.get(current)][mode] = curr_state_one.minimum(curr_state_two);
                  //  System.out.println("At " + current + " and sending "+ DP[map.get(current)][mode].cost() +" one single qty up mode 1");
                    return DP[map.get(current)][mode];
    private static int max(int a, int b)
        if(a > b) return a;
        else return b;

    private static int min(int a, int b)
        if(a < b) return a;
        else return b;

    private static class State
        public int single_qty;
        public int family_qty;
        public State()
            single_qty = 0;
            family_qty = 0;
        public int cost()
            return (single_qty * single_cost) + (family_qty * family_cost); 
        public State minimum(State t)
            if(this.cost() < t.cost()) return this;
            else return t;
        public void data_out(int index)
            out.println(index + ". " + this.single_qty + " " + this.family_qty + " " + this.cost());
 public static class MyScanner 
      BufferedReader br;
      StringTokenizer st;
      public MyScanner() 
         br = new BufferedReader(new InputStreamReader(;
      String next() 
          while (st == null || !st.hasMoreElements()) 
                  st = new StringTokenizer(br.readLine());
              } catch (IOException e) 
          return st.nextToken();
      int nextInt() 
          return Integer.parseInt(next());
      long nextLong() 
          return Long.parseLong(next());
      double nextDouble() 
          return Double.parseDouble(next());
      String nextLine()
          String str = "";
	     str = br.readLine();
	  } catch (IOException e) 
	  return str;

