Naturalnym pomysłem na rozwiązanie zadania jest zastosowanie metody programowania dynamicznego. Będziemy wypełniać tabelkę, gdzie $dp[i]$ oznacza koszt budowy mostu kończącego się $i$-tym filarem.
Dla uproszczenia najpierw zabezpieczymy pieniądze na wyburzenie wszystkich filarów w wysokości $W = \sum_i w_i$. Dzięki temu będziemy mogli uwzględniać koszt niewyburzenia filaru równy $-w_i$.
Zatem $dp[1] = -w_1$, bo mamy wtedy most składający się z pierwszego filaru. Dla $i>1$ mamy prostą zależność rekurencyjną: $$dp[i] = \min_{j < i} \big( dp[j] + (h_i-h_j)^2 - w_i \big).$$
Odpowiedzią jest oczywiście $dp[n] + W$, a całość można zapisać następującym programem działającym w czasie $O(n^2)$:
const int N = 110000; int n,h[N],w[N]; ll dp[N],sumw; ll sq(int x) { return ll(x)*x; } int main() { scanf("%d",&n); REP(i,n) scanf("%d",&h[i]); REP(i,n) scanf("%d",&w[i]); REP(i,n) sumw += w[i]; dp[0] = -w[0]; for (int i=1; i<n; ++i) { dp[i] = ll(1e18); for (int j=i-1; j>=0; --j) { dp[i] = min(dp[i], dp[j] + sq(h[i] - h[j]) - w[i]); } } ll ans = dp[n-1] + sumw; printf("%lld\n", ans); }
Możemy trochę inaczej rozpisać sobie koszt budowy. Przęsło mostu między filarami $i$ oraz $j$ kosztuje $$(h_i - h_j)^2 = h_i^2 - 2h_ih_j + h_j^2.$$ Składniki $h_i^2$ będą występować w sumarycznym koszcie dwukrotnie dla każdego wybranego filaru (oprócz pierwszego i ostatniego, dla których będą występować jednokrotnie). Zatem możemy zmienić koszt budowy przęsła na $$2h_i^2 - 2h_ih_j,$$ a następnie całkowity koszt to będzie $W - w_1 + h_1^2 - h_n^2$ plus koszt wynikający z budowy przęseł.
Niech $dp[i]$ będzie kosztem budowy mostu kończącego się $i$-tym filarem. Wtedy $$dp[i] = \min_{j<i}\big( dp[j] - w_i + 2h_i^2 - 2h_ih_j \big) = \min_{j<i} \big( a_j (-2h_i) + b_j + c_i \big),$$ gdzie $a_j = dp[j]$ oraz $b_j$ są pewnymi stałymi dla $j$-tego filaru, a $c_i = -w_i + 2h_i$ jest stałą dla $i$-tego filaru.
Tak więc potrzebujemy umieć minimalizować wartości funkcji liniowych $a_j x + b_j$ dla zbioru filarów, oraz dodać taką funkcję $a_i x + b_i$ dla nowego filaru.
Poniżej kod, który nieefektywnie realizuje taką strukturę w czasie $O(n^2)$:
struct oracle { vector<pair<ll,ll>> v; ll get(ll x) { ll ans = 1e18; for (auto [a, b] : v) { ans = min(ans, a*x + b); } return ans; } void insert(ll a, ll b) { v.emplace_back(a, b); } }; int main() { scanf("%d",&n); REP(i,n) scanf("%d",&h[i]); REP(i,n) scanf("%d",&w[i]); REP(i,n) sumw += w[i]; oracle s; s.insert(h[0], sumw - w[0] + sq(h[0]) - sq(h[n-1])); for (int i=1; i<n; ++i) { cost = -w[i] + 2*sq(h[i]) + s.get(-2*h[i]); s.insert(h[i], cost); } printf("%lld\n", cost); }
Ale strukturę danych można zaimplementować, aby operacje działały w zamortyzowanym czasie $O(\log n)$. Jest to tzw. convex hull optimization, gdzie utrzymujemy górną otoczkę wypukłą zbioru równań liniowych. Poniżej przykładowa implementacja tego pomysłu:
const ll infty = 1e18; inline ll ceil_div(ll a, ll b) { assert(b > 0); if (a < 0) return -( (-a)/b ); else return (a + b-1) / b; } struct line { ll a, b; mutable ll xl; bool operator<(const line& other) const { if (max(a, other.a) == infty+1) { return xl < other.xl; } else { return make_pair(a, b) < make_pair(other.a, other.b); } } ll inters(const line& prev) const { return ceil_div(prev.b - b, a - prev.a); } bool valid(const line& prev, const line& next) const { return inters(prev) < next.inters(*this); } void update(const line& prev) const { xl = inters(prev); } }; struct envelope { set<line> s; void insert(ll a, ll b) { assert(max(abs(a), abs(b)) <= infty); auto it = s.insert({ a, b, -2*infty }).first; if (next(it) != s.end() && a == next(it)->a) { s.erase(it); return; } if (it != s.begin() && prev(it)->a == a) { s.erase(prev(it)); } if (it != s.begin() && next(it) != s.end() && !it->valid(*prev(it), *next(it))) { s.erase(it); return; } if (it != s.begin()) { auto pit = prev(it); while (pit != s.begin() && !pit->valid(*prev(pit), *it)) { s.erase(pit); pit = prev(it); } it->update(*pit); } if (next(it) != s.end()) { auto nit = next(it); while (next(nit) != s.end() && !nit->valid(*it, *next(nit))) { s.erase(nit); nit = next(it); } nit->update(*it); } } bool empty() const { return s.empty(); } void clear() { s.clear(); } pair<ll,ll> eval(ll x) const { auto it = prev(s.upper_bound({ infty+1, 0, x })); return make_pair(it->a, it->b); } }; struct oracle { envelope env; ll get(ll x) { auto [a, b] = env.eval(x); return -a*x - b; } void insert(ll a, ll b) { env.insert(-a, -b); } };