Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Fixed
- Fix issue where Plotly Express ignored trace-specific color sequences defined in templates via `template.data.<trace_type>` [[#5437](https://github.com/plotly/plotly.py/pull/5437)]

### Updated
- Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution!

Expand Down
25 changes: 23 additions & 2 deletions plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def one_group(x):
return ""


def apply_default_cascade(args):
def apply_default_cascade(args, constructor=None):
# first we apply px.defaults to unspecified args

for param in defaults.__slots__:
Expand Down Expand Up @@ -1037,9 +1037,30 @@ def apply_default_cascade(args):
if args["color_continuous_scale"] is None:
args["color_continuous_scale"] = sequential.Viridis

# if color_discrete_sequence not set explicitly or in px.defaults,
# see if we can defer to template. Try trace-specific colors first,
# then layout.colorway, then set reasonable defaults
if "color_discrete_sequence" in args:
if args["color_discrete_sequence"] is None and constructor is not None:
if constructor == "timeline":
trace_type = "bar"
else:
trace_type = constructor().type
if trace_data_list := getattr(args["template"].data, trace_type, None):
args["color_discrete_sequence"] = [
trace_data.marker.color
for trace_data in trace_data_list
if hasattr(trace_data, "marker")
and hasattr(trace_data.marker, "color")
]
if not args["color_discrete_sequence"] or not any(
args["color_discrete_sequence"]
):
args["color_discrete_sequence"] = None
# fallback to layout.colorway if trace-specific colors not available
if args["color_discrete_sequence"] is None and args["template"].layout.colorway:
args["color_discrete_sequence"] = args["template"].layout.colorway
# final fallback to default qualitative palette
if args["color_discrete_sequence"] is None:
args["color_discrete_sequence"] = qualitative.D3

Expand Down Expand Up @@ -2486,7 +2507,7 @@ def get_groups_and_orders(args, grouper):
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
trace_patch = trace_patch or {}
layout_patch = layout_patch or {}
apply_default_cascade(args)
apply_default_cascade(args, constructor=constructor)

args = build_dataframe(args, constructor)
if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_optional/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,52 @@ def test_px_templates(backend):
pio.templates.default = "plotly"


def test_px_templates_trace_specific_colors(backend):
import pandas as pd

tips = px.data.tips(return_type=backend)

# trace-specific colors: each trace type uses its own template colors
template = {
"data_histogram": [
{"marker": {"color": "orange"}},
{"marker": {"color": "purple"}},
],
"data_bar": [
{"marker": {"color": "red"}},
{"marker": {"color": "blue"}},
],
"layout_colorway": ["yellow", "green"],
}
# histogram uses histogram colors
fig = px.histogram(tips, x="total_bill", color="sex", template=template)
assert fig.data[0].marker.color == "orange"
assert fig.data[1].marker.color == "purple"
# fallback to layout.colorway when trace-specific colors don't exist
fig = px.box(tips, x="day", y="total_bill", color="sex", template=template)
assert fig.data[0].marker.color == "yellow"
assert fig.data[1].marker.color == "green"
# timeline special case (maps to bar)
df_timeline = pd.DataFrame(
{
"Task": ["Job A", "Job B"],
"Start": ["2009-01-01", "2009-03-05"],
"Finish": ["2009-02-28", "2009-04-15"],
"Resource": ["Alex", "Max"],
}
)
fig = px.timeline(
df_timeline,
x_start="Start",
x_end="Finish",
y="Task",
color="Resource",
template=template,
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"


def test_px_defaults():
px.defaults.labels = dict(x="hey x")
px.defaults.category_orders = dict(color=["b", "a"])
Expand Down