diff --git a/CHANGELOG.md b/CHANGELOG.md index afd0a77ebc8..d408cc2ad56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +### Fixed +- Fix issue where Plotly Express ignored trace-specific color sequences defined in templates via `template.data.` [[#5437](https://github.com/plotly/plotly.py/pull/5437)] + ### Updated - Speed up `validate_gantt` function [[#5386](https://github.com/plotly/plotly.py/pull/5386)], with thanks to @misrasaurabh1 for the contribution! diff --git a/plotly/express/_core.py b/plotly/express/_core.py index d2dbc84c0e7..04449a186e3 100644 --- a/plotly/express/_core.py +++ b/plotly/express/_core.py @@ -1003,7 +1003,7 @@ def one_group(x): return "" -def apply_default_cascade(args): +def apply_default_cascade(args, constructor=None): # first we apply px.defaults to unspecified args for param in defaults.__slots__: @@ -1037,9 +1037,30 @@ def apply_default_cascade(args): if args["color_continuous_scale"] is None: args["color_continuous_scale"] = sequential.Viridis + # if color_discrete_sequence not set explicitly or in px.defaults, + # see if we can defer to template. Try trace-specific colors first, + # then layout.colorway, then set reasonable defaults if "color_discrete_sequence" in args: + if args["color_discrete_sequence"] is None and constructor is not None: + if constructor == "timeline": + trace_type = "bar" + else: + trace_type = constructor().type + if trace_data_list := getattr(args["template"].data, trace_type, None): + args["color_discrete_sequence"] = [ + trace_data.marker.color + for trace_data in trace_data_list + if hasattr(trace_data, "marker") + and hasattr(trace_data.marker, "color") + ] + if not args["color_discrete_sequence"] or not any( + args["color_discrete_sequence"] + ): + args["color_discrete_sequence"] = None + # fallback to layout.colorway if trace-specific colors not available if args["color_discrete_sequence"] is None and args["template"].layout.colorway: args["color_discrete_sequence"] = args["template"].layout.colorway + # final fallback to default qualitative palette if args["color_discrete_sequence"] is None: args["color_discrete_sequence"] = qualitative.D3 @@ -2486,7 +2507,7 @@ def get_groups_and_orders(args, grouper): def make_figure(args, constructor, trace_patch=None, layout_patch=None): trace_patch = trace_patch or {} layout_patch = layout_patch or {} - apply_default_cascade(args) + apply_default_cascade(args, constructor=constructor) args = build_dataframe(args, constructor) if constructor in [go.Treemap, go.Sunburst, go.Icicle] and args["path"] is not None: diff --git a/tests/test_optional/test_px/test_px.py b/tests/test_optional/test_px/test_px.py index 3a8ddabcd0a..7b7218f07c1 100644 --- a/tests/test_optional/test_px/test_px.py +++ b/tests/test_optional/test_px/test_px.py @@ -226,6 +226,52 @@ def test_px_templates(backend): pio.templates.default = "plotly" +def test_px_templates_trace_specific_colors(backend): + import pandas as pd + + tips = px.data.tips(return_type=backend) + + # trace-specific colors: each trace type uses its own template colors + template = { + "data_histogram": [ + {"marker": {"color": "orange"}}, + {"marker": {"color": "purple"}}, + ], + "data_bar": [ + {"marker": {"color": "red"}}, + {"marker": {"color": "blue"}}, + ], + "layout_colorway": ["yellow", "green"], + } + # histogram uses histogram colors + fig = px.histogram(tips, x="total_bill", color="sex", template=template) + assert fig.data[0].marker.color == "orange" + assert fig.data[1].marker.color == "purple" + # fallback to layout.colorway when trace-specific colors don't exist + fig = px.box(tips, x="day", y="total_bill", color="sex", template=template) + assert fig.data[0].marker.color == "yellow" + assert fig.data[1].marker.color == "green" + # timeline special case (maps to bar) + df_timeline = pd.DataFrame( + { + "Task": ["Job A", "Job B"], + "Start": ["2009-01-01", "2009-03-05"], + "Finish": ["2009-02-28", "2009-04-15"], + "Resource": ["Alex", "Max"], + } + ) + fig = px.timeline( + df_timeline, + x_start="Start", + x_end="Finish", + y="Task", + color="Resource", + template=template, + ) + assert fig.data[0].marker.color == "red" + assert fig.data[1].marker.color == "blue" + + def test_px_defaults(): px.defaults.labels = dict(x="hey x") px.defaults.category_orders = dict(color=["b", "a"])