This is a DP problem on trees and I tried to solve it using the following approach.

I store a DP state DP[current][mode] = the minimum cost for the subtree rooted at current, where mode is 0 if the parent of the current node has not bought a family ticket and 1 if it has.

The recurrence is:

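If mode == 0 (the current node is not covered by a family ticket bought by its parent):

DP[current][0] = min(cost_of_single_ticket + sum of DP[child][0] over all children,
                     cost_of_family_ticket + sum of DP[child][1] over all children)

If mode == 1 (the parent's family ticket already covers the current node, so it pays nothing itself unless it buys its own family ticket):

DP[current][1] = min(sum of DP[child][0] over all children,
                     cost_of_family_ticket + sum of DP[child][1] over all children)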


The algorithm seems right, but I am getting Wrong Answer. My code is below, but it is long and confusing. If you have already solved the problem, could you provide some tricky test cases?

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.StringTokenizer;

/**
 * Reads test cases and, for each one, prints the cheapest mix of single and
 * family tickets as {@code "k. singles families cost"}.
 *
 * <p>Input as parsed below: a numeric line {@code "singleCost familyCost"}
 * starts a case; each following non-numeric line is
 * {@code "parent child1 child2 ..."}. The next numeric line closes the case
 * and starts another; {@code "0 0"} (or end of input) ends everything.
 */
public class Main
{
    private static MyScanner sc;
    private static PrintWriter out;

    /** Price of a single ticket for the case currently being read or solved. */
    private static int singleCost;
    /** Price of a family ticket for the case currently being read or solved. */
    private static int familyCost;
    /** Person name -> dense id, used to index {@link #children} and {@link #DP}. */
    private static HashMap<String, Integer> ids;
    /** children.get(id) holds the names listed as children of person {@code id}. */
    private static ArrayList<ArrayList<String>> children;
    /** DP[id][mode] memoizes {@link #solver}; mode 1 means the parent bought a family ticket. */
    private static State[][] DP;

    public static void main(String[] args)
    {
        sc = new MyScanner();
        out = new PrintWriter(System.out);
        String first = sc.next();
        if (first != null)  // empty input: nothing to do
        {
            singleCost = Integer.parseInt(first);
            familyCost = sc.nextInt();
            readAndSolveCases();
        }
        out.close();
    }

    /** Consumes the remaining input line by line, solving each case once its end is seen. */
    private static void readAndSolveCases()
    {
        int counter = 1;
        ids = new HashMap<>();
        HashMap<String, ArrayList<String>> data = new HashMap<>();

        // "0 0" as the current costs means the input is finished.
        while (singleCost != 0 || familyCost != 0)
        {
            String raw = sc.nextLine();
            if (raw == null)
            {
                // End of input without a "0 0" terminator: still report the pending case.
                solveCase(data, counter++);
                return;
            }
            String text = raw.trim();  // leading blanks would otherwise yield an empty name
            if (text.isEmpty()) continue;

            String[] line = text.split("\\s+");
            if (isNumber(line[0]))
            {
                // A numeric line closes the current case; its costs belong to the next case.
                solveCase(data, counter++);
                singleCost = Integer.parseInt(line[0]);
                familyCost = Integer.parseInt(line[1]);
                data = new HashMap<>();
                ids = new HashMap<>();
            }
            else
            {
                register(line[0]);
                ArrayList<String> kids = data.get(line[0]);
                if (kids == null)
                {
                    kids = new ArrayList<>();
                    data.put(line[0], kids);
                }
                for (int i = 1; i < line.length; i++)
                {
                    register(line[i]);
                    kids.add(line[i]);
                }
            }
        }
    }

    /** Assigns the next dense id to {@code name} if it does not have one yet. */
    private static void register(String name)
    {
        if (!ids.containsKey(name)) ids.put(name, ids.size());
    }

    /**
     * Builds the forest for one case, runs the DP from every root and prints
     * {@code "index. singles families cost"} priced with the current costs.
     *
     * @param data  parent name -> names listed as its children
     * @param index 1-based case number used as the output prefix
     */
    private static void solveCase(HashMap<String, ArrayList<String>> data, int index)
    {
        int n = ids.size();
        DP = new State[n][2];
        children = new ArrayList<>(n);
        for (int i = 0; i < n; i++) children.add(new ArrayList<String>());

        HashSet<String> hasParent = new HashSet<>();
        for (String parent : data.keySet())
        {
            ArrayList<String> kids = data.get(parent);
            children.get(ids.get(parent)).addAll(kids);
            hasParent.addAll(kids);
        }

        int singles = 0;
        int families = 0;
        long cost = 0;
        for (String person : data.keySet())
        {
            if (hasParent.contains(person)) continue;  // only roots start a DP
            State best = solver(person, 0);  // a root has no parent, so mode 0
            singles += best.singles;
            families += best.families;
            cost += best.cost();
        }
        out.println(index + ". " + singles + " " + families + " " + cost);
    }

    /** Returns true when {@code a} parses as an int (used to detect cost lines). */
    private static boolean isNumber(String a)
    {
        try
        {
            Integer.parseInt(a);
            return true;
        }
        catch (NumberFormatException ignored)
        {
            return false;
        }
    }

    /**
     * Returns the cheapest ticket assignment for the subtree rooted at {@code current}.
     * NOTE(review): recursion depth equals the tree height; very deep chains may
     * need a larger thread stack.
     *
     * @param current person at the root of the subtree
     * @param mode    0 if current's parent bought no family ticket (current must be
     *                covered here), 1 if it did (current is already covered)
     * @return memoized ticket counts for this (node, mode) pair; callers must not mutate it
     */
    private static State solver(String current, int mode)
    {
        int id = ids.get(current);
        if (DP[id][mode] != null) return DP[id][mode];

        // Option 1: current buys no family ticket, so its children are uncovered (mode 0).
        // Current itself needs a single ticket unless its parent's family ticket covers it.
        State noFamily = new State();
        if (mode == 0) noFamily.singles = 1;
        // Option 2: current buys a family ticket covering itself and its children (mode 1).
        State withFamily = new State();
        withFamily.families = 1;

        ArrayList<String> kids = children.get(id);
        for (String child : kids)
        {
            noFamily.add(solver(child, 0));
            withFamily.add(solver(child, 1));
        }

        // Leaves now follow the same recurrence (a family ticket is taken when strictly
        // cheaper). Ties keep the previous choices: single for leaves, family for inner nodes.
        // NOTE(review): if the judge prescribes a tie-break rule, apply it here.
        State best = kids.isEmpty() ? withFamily.minimum(noFamily) : noFamily.minimum(withFamily);
        DP[id][mode] = best;
        return best;
    }

    /** Ticket counts of a (partial) assignment; priced with the current case's costs. */
    private static final class State
    {
        int singles;
        int families;

        /** Total price; computed in long so large cases cannot overflow. */
        long cost()
        {
            return (long) singles * singleCost + (long) families * familyCost;
        }

        /** Adds another assignment's ticket counts into this one. */
        void add(State other)
        {
            singles += other.singles;
            families += other.families;
        }

        /** Returns the cheaper of this and {@code t}; {@code t} wins ties. */
        State minimum(State t)
        {
            return cost() < t.cost() ? this : t;
        }
    }

    /** Buffered token/line reader over standard input. */
    public static class MyScanner
    {
        BufferedReader br;
        StringTokenizer st;

        public MyScanner()
        {
            br = new BufferedReader(new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, or null at end of input. */
        String next()
        {
            while (st == null || !st.hasMoreElements())
            {
                String line = readRawLine();
                if (line == null) return null;
                st = new StringTokenizer(line);
            }
            return st.nextToken();
        }

        int nextInt()
        {
            return Integer.parseInt(next());
        }

        long nextLong()
        {
            return Long.parseLong(next());
        }

        double nextDouble()
        {
            return Double.parseDouble(next());
        }

        /** Returns the next raw input line (unread tokens are skipped), or null at end of input. */
        String nextLine()
        {
            return readRawLine();
        }

        /** Reads one line, surfacing I/O failures instead of silently swallowing them. */
        private String readRawLine()
        {
            try
            {
                return br.readLine();
            }
            catch (IOException e)
            {
                throw new UncheckedIOException(e);
            }
        }
    }
}


