Tree Analysis#

This tutorial demonstrates how to analyze the saved results of a retrosynthetic tree search. We load the analysis data exported by the Retrosynthetic Planning tutorial and explore:

  • Policy performance: rule applicability rate and dead-end analysis

  • Search dynamics: when routes were discovered during the search

  • Winning rule ranks: how far from the policy’s top prediction the solution rules were

  • Tree shape: branching profile across depth levels

  • Route details: per-step breakdown of individual routes

Prerequisites

Run the Retrosynthetic Planning tutorial first to generate the tutorial_results/tree_analysis.json file.

1. Loading saved results#

[ ]:
import json
from pathlib import Path

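# Analysis data exported by the Retrosynthetic Planning tutorial
results_path = Path("tutorial_results/tree_analysis.json")
with open(results_path, encoding="utf-8") as f:
    data = json.load(f)

# Show the target and the top-level sections available for analysis
print(f"Target: {data['target_smiles']}")
print(f"Sections: {list(data.keys())}")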
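Before going further, it can help to confirm that the loaded file contains every section this tutorial reads. The sketch below checks a dictionary against the section names used on this page. `missing_sections` is a hypothetical helper, and the sample dictionary is a made-up placeholder rather than real search output.

```python
# Section names read by the cells in this tutorial.
EXPECTED_SECTIONS = (
    "target_smiles",
    "summary",
    "branching_profile",
    "winning_rule_ranks",
    "route_details",
    "routes_found_at",
)


def missing_sections(data, expected=EXPECTED_SECTIONS):
    """Return the expected section names that are absent from `data`."""
    return [name for name in expected if name not in data]


# Made-up placeholder standing in for the loaded JSON (not real results).
example = {"target_smiles": "CCO", "summary": {}, "route_details": []}
print(f"Missing sections: {missing_sections(example)}")
```

In the notebook, calling `missing_sections(data)` on the loaded file should return an empty list if the export is complete.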

2. Search summary#

The summary section contains the same flat dictionary that to_stats_dict() returns — all key metrics in one place.

[ ]:
summary = data["summary"]

print(f"{'Metric':<30} {'Value'}")
print("-" * 50)
for key, value in summary.items():
    print(f"{key:<30} {value}")

3. Policy performance#

Rule applicability rate#

The rule applicability rate is the fraction of policy-predicted rules that actually produced valid products. This is a key indicator of how well the policy matches the chemistry:

| Rate | Interpretation |
|------|----------------|
| > 0.5 | Policy predicts mostly applicable rules; efficient search |
| 0.2 – 0.5 | Moderate; some wasted expansion effort |
| < 0.2 | Many predicted rules fail; policy may need retraining or more data |
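The bands above can be turned into a small helper, which is handy when scanning many targets. This is a minimal sketch: the function names and the short labels ("efficient", "moderate", "poor") are illustrative choices, and `applicability_rate` assumes the rate is simply rules succeeded divided by rules tried, which is how the definition above reads.

```python
def applicability_rate(rules_tried, rules_succeeded):
    """Fraction of tried rules that produced valid products (assumed definition)."""
    if rules_tried == 0:
        return 0.0
    return rules_succeeded / rules_tried


def interpret_applicability(rate):
    """Map a rate in [0, 1] to the three bands of the interpretation table."""
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f"rate must be between 0 and 1, got {rate}")
    if rate > 0.5:
        return "efficient"
    if rate >= 0.2:
        return "moderate"
    return "poor"


# Made-up counts for illustration only.
rate = applicability_rate(rules_tried=200, rules_succeeded=50)
print(f"{rate:.2%} -> {interpret_applicability(rate)}")
```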

[ ]:
s = summary
print(f"Rule applicability rate: {s['rule_applicability_rate']:.2%}")
print(f"  Rules tried:     {s['total_rules_tried']}")
print(f"  Rules succeeded: {s['total_rules_succeeded']}")
print(f"  Dead-end nodes:  {s['dead_end_nodes']}")
print()
print("Expansion efficiency:")
print(f"  Expansion calls:     {s['expansion_calls']}")
print(f"  Expansion successes: {s['expansion_successes']}")
if s['expansion_calls'] > 0:
    print(f"  Success rate:        {s['expansion_successes'] / s['expansion_calls']:.2%}")

Winning rule ranks#

For each solved route, the winning rule rank tells you where the successful rule was in the policy’s sorted predictions at each step. Rank 1 means the policy’s top-ranked rule led to the solution.

This answers: “Does the policy predict the correct rules at the top, or does the search have to dig deep into the predictions?”

[ ]:
ranks_info = data["winning_rule_ranks"]

if ranks_info:
    all_ranks = [step["rank"] for route in ranks_info for step in route["steps"]]
    print(f"Rank statistics across {len(ranks_info)} winning routes ({len(all_ranks)} total steps):")
    print(f"  Mean rank: {sum(all_ranks) / len(all_ranks):.1f}")
    print(f"  Max rank:  {max(all_ranks)}")
    print(f"  Rank 1 (top prediction): {all_ranks.count(1)} / {len(all_ranks)} steps ({all_ranks.count(1)/len(all_ranks):.0%})")
else:
    print("No winning routes found.")
[ ]:
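if ranks_info:
    from collections import Counter

    rank_counts = Counter(all_ranks)
    # Cap the histogram at rank 20 to keep the output readable
    max_rank_to_show = min(max(rank_counts), 20)

    print("Rank distribution:")
    for rank in range(1, max_rank_to_show + 1):
        count = rank_counts.get(rank, 0)
        bar = "#" * count
        print(f"  Rank {rank:>3d}: {bar} ({count})")

    # Report steps hidden by the cap so the histogram is not misleading
    hidden = sum(c for r, c in rank_counts.items() if r > max_rank_to_show)
    if hidden:
        print(f"  Rank >{max_rank_to_show}: {hidden} steps not shown")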
[ ]:
if ranks_info:
    # Show the first route as an example
    route = ranks_info[0]
    print(f"Example: route ending at node {route['winning_node_id']}")
    print(f"{'Step':<6} {'Node':<8} {'Rule':<8} {'Prob':<10} {'Rank'}")
    print("-" * 40)
    for i, step in enumerate(route["steps"], 1):
        print(f"{i:<6} {step['node_id']:<8} {step['rule_id']:<8} {step['prob']:<10.4f} {step['rank']}")
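A compact way to answer the question of how deep the search has to dig is top-k coverage: the fraction of winning steps whose rule was among the policy's top k predictions. The sketch below assumes a flat list of integer ranks like `all_ranks` above. `top_k_coverage` is a hypothetical helper and the ranks are made-up example values.

```python
def top_k_coverage(ranks, k):
    """Fraction of winning steps whose rule rank is at most k."""
    if not ranks:
        return 0.0
    return sum(1 for rank in ranks if rank <= k) / len(ranks)


# Made-up ranks for illustration; in the notebook, pass `all_ranks` instead.
example_ranks = [1, 1, 3, 2, 7, 1, 12, 4]
for k in (1, 3, 5, 10):
    print(f"Top-{k:<2d} coverage: {top_k_coverage(example_ranks, k):.1%}")
```

A policy with high top-1 or top-3 coverage rarely forces the search past its first few predictions.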

4. Search dynamics#

Route discovery over time#

The routes_found_at list records when each route was discovered (iteration number and wall-clock time). This reveals the search convergence pattern.

[ ]:
routes_found = data["routes_found_at"]

if routes_found:
    first_iter, first_time = routes_found[0]
    total_iter = summary["num_iter"]
    total_time = summary["search_time"]

    print(f"First solution: iteration {first_iter}/{total_iter} ({first_time:.2f}s / {total_time}s)")
    print(f"Total routes found: {len(routes_found)}")
    print()

    # Routes per quarter of iterations
    quarters = [total_iter * q // 4 for q in range(1, 5)]
    print("Routes found by iteration:")
    for q_iter in quarters:
        count = sum(1 for it, _ in routes_found if it <= q_iter)
        print(f"  By iteration {q_iter:>4d}: {count} routes")
else:
    print("No routes found during search.")
    print(f"Search ran for {summary['num_iter']} iterations in {summary['search_time']}s")
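Two further quantities can be read off the same list: how much of the iteration budget was spent before the first route appeared, and how many iterations passed between consecutive discoveries. The sketch below assumes each entry is an (iteration, time) pair, as in the cell above. The helper names are hypothetical and the data is made up.

```python
def first_solution_fraction(routes_found, num_iter):
    """Share of the iteration budget used before the first route was found."""
    if not routes_found or num_iter <= 0:
        return None
    first_iter = min(iteration for iteration, _ in routes_found)
    return first_iter / num_iter


def discovery_gaps(routes_found):
    """Iterations elapsed between consecutive route discoveries."""
    iterations = sorted(iteration for iteration, _ in routes_found)
    return [later - earlier for earlier, later in zip(iterations, iterations[1:])]


# Made-up (iteration, seconds) pairs for illustration only.
example_routes = [[12, 0.8], [15, 1.0], [16, 1.1], [90, 6.3]]
print(f"First solution after {first_solution_fraction(example_routes, 100):.0%} of the budget")
print(f"Gaps between discoveries: {discovery_gaps(example_routes)}")
```

Small gaps followed by a long one suggest that the search found a cluster of related routes early and only later reached a different region of the tree.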

Branching profile#

The branching profile shows the mean number of children per expanded node at each depth level. This reveals how the tree fans out and whether deeper nodes have fewer expansion options.

[ ]:
profile = data["branching_profile"]

print(f"{'Depth':<8} {'Mean children':<16} {'Expanded nodes':<16} {'Visual'}")
print("-" * 60)
for depth, info in sorted(profile.items(), key=lambda x: int(x[0])):
    mean_c = info["mean_children"]
    nodes = info["nodes"]
    bar = "*" * int(mean_c)
    print(f"{depth:<8} {mean_c:<16} {nodes:<16} {bar}")

print()
print(f"Overall: max branching = {summary['max_branching_factor']}, "
      f"mean branching = {summary['mean_branching_factor']}")
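Because each depth level reports a mean over a different number of expanded nodes, averaging the per-depth means directly would overweight sparsely populated levels. The sketch below computes a node-weighted mean from a profile with the same `mean_children` and `nodes` fields as above. The helper name is hypothetical, the numbers are made up, and the result is not guaranteed to match `summary['mean_branching_factor']`, since this tutorial does not state how that value is computed.

```python
def weighted_mean_branching(profile):
    """Mean children per expanded node, weighting each depth by its node count."""
    total_nodes = sum(info["nodes"] for info in profile.values())
    if total_nodes == 0:
        return 0.0
    total_children = sum(info["mean_children"] * info["nodes"] for info in profile.values())
    return total_children / total_nodes


# Made-up profile for illustration; keys are depth levels stored as strings.
example_profile = {
    "0": {"mean_children": 8.0, "nodes": 1},
    "1": {"mean_children": 4.0, "nodes": 6},
    "2": {"mean_children": 1.5, "nodes": 10},
}
unweighted = sum(info["mean_children"] for info in example_profile.values()) / len(example_profile)
print(f"Unweighted mean of per-depth means: {unweighted:.2f}")
print(f"Node-weighted mean:                 {weighted_mean_branching(example_profile):.2f}")
```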

5. Route details#

Each saved route includes the per-step breakdown: rule used, its policy probability, the node’s evaluation value, visit count, and whether that step’s precursors are fully solved.

[ ]:
route_details = data["route_details"]

if route_details:
    # Find the best-scoring route
    best = max(route_details, key=lambda r: r["route_score"])
    print(f"Best route (node {best['node_id']}, score={best['route_score']:.6f}, "
          f"length={best['route_length']} steps):")
    print()
    for step in best["steps"]:
        status = "solved" if step["is_solved"] else "open"
        print(f"  Depth {step['depth']}: node {step['node_id']}, rule={step['rule_id']}, "
              f"prob={step['prob']:.4f}, value={step['init_value']:.4f}, "
              f"visits={step['visits']}, {status}, {step['n_precursors']} precursors")
else:
    print("No route details available.")
[ ]:
if route_details:
    # Summary across all routes
    lengths = [r["route_length"] for r in route_details]
    scores = [r["route_score"] for r in route_details]

    print(f"Route summary ({len(route_details)} routes):")
    print(f"  Length: min={min(lengths)}, max={max(lengths)}, mean={sum(lengths)/len(lengths):.1f}")
    print(f"  Score:  min={min(scores):.6f}, max={max(scores):.6f}, mean={sum(scores)/len(scores):.6f}")

    # Route length distribution
    from collections import Counter
    length_counts = Counter(lengths)
    print("\n  Length distribution:")
    for length in sorted(length_counts.keys()):
        count = length_counts[length]
        bar = "#" * count
        print(f"    {length} steps: {bar} ({count})")
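When several routes score similarly, a ranked shortlist is often more useful than the single best route. The sketch below sorts routes by score and breaks ties in favor of shorter routes. The tie-break rule and the helper name are illustrative choices, and the records are made up, using only the `node_id`, `route_score` and `route_length` fields shown above.

```python
def top_routes(route_details, n=3):
    """Return the n best routes: highest score first, shorter route on ties."""
    ranked = sorted(route_details, key=lambda r: (-r["route_score"], r["route_length"]))
    return ranked[:n]


# Made-up route records for illustration only.
example_routes = [
    {"node_id": 41, "route_score": 0.62, "route_length": 4},
    {"node_id": 17, "route_score": 0.80, "route_length": 5},
    {"node_id": 58, "route_score": 0.80, "route_length": 3},
    {"node_id": 23, "route_score": 0.35, "route_length": 2},
]
for route in top_routes(example_routes):
    print(f"node {route['node_id']}: score={route['route_score']:.2f}, length={route['route_length']}")
```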

6. Batch analysis from CSV#

When run_search or the SAScore benchmark is run over many targets, the results are saved to tree_search_stats.csv. The cell below loads that file and previews the key metrics for the first five targets.

[ ]:
import csv

csv_path = Path("tutorial_results/tree_search_stats.csv")
if csv_path.exists():
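    # newline="" is the csv module's recommended way to open CSV files
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # fieldnames is None for an empty file, so fall back to an empty list
        columns = reader.fieldnames or []

    print(f"Loaded {len(rows)} search results from CSV")
    print(f"Columns: {list(columns)}")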
    print()
    for row in rows[:5]:
        smiles = row.get("target_smiles", "N/A")
        solved = row.get("solved", "N/A")
        routes = row.get("num_routes", "N/A")
        rate = row.get("rule_applicability_rate", "N/A")
        print(f"  {smiles[:50]:<50} solved={solved} routes={routes} applicability={rate}")
else:
    print(f"{csv_path} not found. Run tutorial 05 first.")
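To compare across targets rather than just preview rows, the per-target values can be aggregated. The sketch below is self-contained and runs on made-up CSV text. It assumes the `solved` column holds the strings "True" or "False" and that `rule_applicability_rate` parses as a float; both are assumptions about the file format, so adjust the parsing if the real file differs.

```python
import csv
import io

# Made-up CSV content for illustration; column names follow the cell above.
SAMPLE_CSV = """target_smiles,solved,num_routes,rule_applicability_rate
CCO,True,4,0.42
c1ccccc1,False,0,0.11
CC(=O)O,True,2,0.35
"""


def aggregate(rows):
    """Summarize solve rate and mean applicability over a list of CSV rows."""
    n = len(rows)
    if n == 0:
        return {"targets": 0, "solve_rate": 0.0, "mean_applicability": 0.0}
    # CSV values are strings, so parse them explicitly (assumed encodings).
    solved = sum(1 for row in rows if row["solved"].strip().lower() == "true")
    mean_rate = sum(float(row["rule_applicability_rate"]) for row in rows) / n
    return {"targets": n, "solve_rate": solved / n, "mean_applicability": mean_rate}


rows = list(csv.DictReader(io.StringIO(SAMPLE_CSV)))
stats = aggregate(rows)
print(f"Targets: {stats['targets']}")
print(f"Solve rate: {stats['solve_rate']:.1%}")
print(f"Mean applicability: {stats['mean_applicability']:.2%}")
```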

7. Interpretation guide#

Diagnosing common issues#

| Symptom | What to check | Possible fix |
|---------|---------------|--------------|
| No routes found | first_solution_iteration is None | Increase max_iterations / max_time |
| Low applicability rate (< 20%) | rule_applicability_rate | Retrain policy with more data, check rule quality |
| High winning rule rank (> 5) | mean_winning_rule_rank | Policy needs more training data for these reaction types |
| Many dead-end nodes | dead_end_nodes vs expansion_calls | Rules may be too specific, or building blocks too limited |
| Late first solution | first_solution_iteration / num_iter | Try higher c_ucb for more exploration |
| Low branching at depth 0 | branching_profile | Few rules apply to the target; may need broader rule set |
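Some of these checks can be automated against the summary dictionary. The sketch below covers the rows that have an explicit threshold (applicability below 0.2, mean winning rank above 5) plus the two first-solution checks. The `late_fraction` cutoff of 0.5 is an assumption made for illustration, since the table does not say how late counts as late, and `diagnose` is a hypothetical helper.

```python
def diagnose(summary, late_fraction=0.5):
    """Return a list of findings for a summary dict, following the table above."""
    findings = []

    if "first_solution_iteration" in summary:
        first = summary["first_solution_iteration"]
        num_iter = summary.get("num_iter")
        if first is None:
            findings.append("No routes found: increase max_iterations / max_time")
        elif num_iter and first / num_iter > late_fraction:  # assumed cutoff
            findings.append("Late first solution: try a higher c_ucb")

    rate = summary.get("rule_applicability_rate")
    if rate is not None and rate < 0.2:
        findings.append("Low applicability rate: retrain the policy, check rule quality")

    rank = summary.get("mean_winning_rule_rank")
    if rank is not None and rank > 5:
        findings.append("High winning rule rank: policy needs more training data")

    return findings


# Made-up summary values for illustration only.
example_summary = {
    "first_solution_iteration": 80,
    "num_iter": 100,
    "rule_applicability_rate": 0.15,
    "mean_winning_rule_rank": 2.4,
}
for finding in diagnose(example_summary):
    print(f"- {finding}")
```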

Comparing policies#

When comparing two policies on the same targets:

  1. Rule applicability rate — higher is better (more efficient search)

  2. Mean winning rule rank — lower is better (policy predicts correct rules at the top)

  3. First solution iteration — lower is better (faster convergence)

  4. Number of routes — more routes give chemists more choices

  5. Best route score — higher is better (better route quality)
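The five criteria above can be applied mechanically to two summary dictionaries. In this sketch each criterion is paired with its preferred direction, and the function reports which policy wins on each one. The key `best_route_score` is a hypothetical name, since this tutorial does not show where the best score is stored, and all values are made up.

```python
# (summary key, True if higher is better)
CRITERIA = [
    ("rule_applicability_rate", True),
    ("mean_winning_rule_rank", False),
    ("first_solution_iteration", False),
    ("num_routes", True),
    ("best_route_score", True),  # hypothetical key name
]


def compare_policies(summary_a, summary_b):
    """Return 'A', 'B', 'tie' or 'n/a' for each criterion."""
    verdicts = {}
    for key, higher_is_better in CRITERIA:
        value_a, value_b = summary_a.get(key), summary_b.get(key)
        if value_a is None or value_b is None:
            verdicts[key] = "n/a"  # metric missing, or no route was found
        elif value_a == value_b:
            verdicts[key] = "tie"
        else:
            a_wins = (value_a > value_b) == higher_is_better
            verdicts[key] = "A" if a_wins else "B"
    return verdicts


# Made-up summaries for two policies on the same target.
policy_a = {"rule_applicability_rate": 0.45, "mean_winning_rule_rank": 2.1,
            "first_solution_iteration": 30, "num_routes": 5, "best_route_score": 0.8}
policy_b = {"rule_applicability_rate": 0.30, "mean_winning_rule_rank": 3.4,
            "first_solution_iteration": 12, "num_routes": 5, "best_route_score": 0.9}
for key, winner in compare_policies(policy_a, policy_b).items():
    print(f"{key:<28} {winner}")
```

For a comparison over many targets, the same idea applies to per-policy averages rather than to a single pair of summaries.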

Summary#

| Data | Source | Description |
|------|--------|-------------|
| summary | tree.to_stats_dict() | Flat dict with all key metrics |
| branching_profile | tree.branching_profile() | Mean branching factor per depth |
| winning_rule_ranks | tree.winning_rule_ranks() | Rank of winning rule at each step |
| route_details | tree.route_details(nid) | Per-step breakdown of routes |
| routes_found_at | tree.stats["routes_found_at"] | Route discovery timing |

All of these are saved by tutorial 05 into tutorial_results/tree_analysis.json. For batch analysis, run_search and the SAScore benchmark write tree_search_stats.csv with the same summary metrics.