ucb_backtest2.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. import logging
  2. import math
  3. import os
  4. from datetime import datetime, timedelta
  5. import backtrader as bt
  6. import ccxt
  7. import numpy as np
  8. import pandas as pd
  9. import matplotlib.pyplot as plt # For equity curve plot
  10. # Configure logging (outputs to console)
  11. # logging.basicConfig(
  12. # level=logging.INFO,
  13. # format='%(asctime)s - %(levelname)s - %(message)s'
  14. # )
  15. # logger = logging.getLogger(__name__)
  16. # Configure logging to both console and file
  17. logging.basicConfig(
  18. level=logging.INFO,
  19. format='%(asctime)s - %(levelname)s - %(message)s',
  20. handlers=[
  21. logging.StreamHandler(), # Console output
  22. logging.FileHandler('ucb_backtest.log', mode='w') # Log file (overwrite each run)
  23. ]
  24. )
  25. logger = logging.getLogger(__name__)
  26. # Top 15 coins (as before)
  27. coins = ['BTC/USDT', 'ETH/USDT', 'BNB/USDT', 'SOL/USDT', 'XRP/USDT', 'ADA/USDT', 'DOGE/USDT',
  28. 'AVAX/USDT', 'SHIB/USDT', 'DOT/USDT', 'LINK/USDT', 'TRX/USDT', 'UNI/USDT', 'LTC/USDT']
  29. exchange = ccxt.binance({'enableRateLimit': True})
  30. # UCB Class (simplified for backtest; updates with rewards each bar)
  31. class UCB:
  32. def __init__(self, num_arms, c=2.0):
  33. self.num_arms = num_arms
  34. self.counts = np.zeros(num_arms)
  35. self.mean_rewards = np.zeros(num_arms)
  36. self.total_pulls = 0
  37. self.c = c
  38. def compute_scores(self):
  39. ucb_scores = np.zeros(self.num_arms)
  40. for i in range(self.num_arms):
  41. if self.counts[i] == 0:
  42. ucb_scores[i] = float('inf') # Encourage exploration
  43. else:
  44. ucb_scores[i] = self.mean_rewards[i] + self.c * math.sqrt(math.log(self.total_pulls + 1) / self.counts[i])
  45. return ucb_scores
  46. def update(self, arm, reward):
  47. self.counts[arm] += 1
  48. self.total_pulls += 1
  49. self.mean_rewards[arm] = (self.mean_rewards[arm] * (self.counts[arm] - 1) + reward) / self.counts[arm]
  50. # Feature Computation Functions (from previous; applied per coin's data)
  51. def compute_atr(df, period=14):
  52. high_low = df['high'] - df['low']
  53. high_close = np.abs(df['high'] - df['close'].shift())
  54. low_close = np.abs(df['low'] - df['close'].shift())
  55. tr = np.maximum(high_low, high_close, low_close)
  56. atr = tr.rolling(period).mean()
  57. return atr
  58. def compute_ema(df, short=12, long=26):
  59. df['ema_short'] = df['close'].ewm(span=short, adjust=False).mean()
  60. df['ema_long'] = df['close'].ewm(span=long, adjust=False).mean()
  61. df['trend'] = np.where(df['ema_short'] > df['ema_long'], 1, -1)
  62. def compute_reward(df):
  63. df['return'] = (df['close'] - df['open']) / df['open']
  64. df['atr'] = compute_atr(df)
  65. compute_ema(df)
  66. reward = df['return'] * df['trend'] / df['atr'].replace(0, np.nan)
  67. return reward.iloc[-1] if not reward.empty else 0 # Latest reward
  68. # Fetch historical OHLCV data for a symbol (with caching)
  69. def fetch_historical_ohlcv(symbol, timeframe='1h', start_date=None, end_date=None, limit=1000, refresh=False):
  70. os.makedirs('dat', exist_ok=True)
  71. program_prefix = 'ucb_backtest'
  72. symbol_safe = symbol.replace('/', '-')
  73. start_str = start_date.strftime('%Y%m%d') if start_date else 'none'
  74. end_str = end_date.strftime('%Y%m%d') if end_date else 'none'
  75. filename = f"dat/{program_prefix}_{symbol_safe}_{timeframe}_{start_str}_{end_str}.csv"
  76. if not refresh and os.path.exists(filename):
  77. try:
  78. df = pd.read_csv(filename, index_col='timestamp', parse_dates=True)
  79. logger.info(f"Loaded cached data for {symbol} from {filename}")
  80. return df
  81. except Exception as e:
  82. logger.warning(f"Error loading cache for {symbol}: {str(e)}; fetching fresh data")
  83. try:
  84. since = int(start_date.timestamp() * 1000) if start_date else None
  85. ohlcv = []
  86. while True:
  87. data = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)
  88. if not data:
  89. break
  90. ohlcv.extend(data)
  91. since = data[-1][0] + 1
  92. if end_date and data[-1][0] >= int(end_date.timestamp() * 1000):
  93. break
  94. df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
  95. df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
  96. df.set_index('timestamp', inplace=True)
  97. if end_date:
  98. df = df[df.index <= end_date]
  99. logger.info(f"Fetched {len(df)} historical candles for {symbol}")
  100. df.to_csv(filename)
  101. logger.info(f"Saved data for {symbol} to {filename}")
  102. return df
  103. except Exception as e:
  104. logger.error(f"Error fetching historical data for {symbol}: {str(e)}")
  105. return pd.DataFrame()
  106. # Custom Backtrader Strategy with UCB
  107. class UCBStrategy(bt.Strategy):
  108. params = (
  109. ('position_size_pct', 0.10), # Fixed 10% per position
  110. ('top_n', 10), # Select top-3 coins
  111. ('min_hold_bars', 4), # Short-term: Min 4 hours
  112. ('max_hold_bars', 36), # Mid-term: Max 24 hours
  113. ('stop_loss_pct', 0.05), # 5% stop-loss
  114. ('take_profit_pct', 0.10), # 10% take-profit
  115. ('ucb_c', 2.0), # UCB exploration param
  116. )
  117. # hold duration min 6: max 24: final port 36.9%
  118. def __init__(self):
  119. self.ucb = UCB(len(coins), self.p.ucb_c)
  120. self.position_entry_bars = {} # Track bars since entry per position
  121. self.position_entry_prices = {} # Track entry price per position
  122. self.holdings_history = [] # To store holdings by date
  123. self.portfolio_value_history = [] # To store portfolio value by date
  124. self.pnl_per_coin = {coin: 0.0 for coin in coins} # <-- Add this line
  125. self.coin_value_history = [] # List of dicts: {'datetime': ..., 'BTC/USDT': ..., ...}
  126. def next(self):
  127. # Step 1: Update UCB with latest rewards for all coins
  128. rewards = []
  129. for i, data in enumerate(self.datas):
  130. df = pd.DataFrame({
  131. 'open': [data.open[0]],
  132. 'high': [data.high[0]],
  133. 'low': [data.low[0]],
  134. 'close': [data.close[0]],
  135. }, index=[data.datetime.datetime()])
  136. reward = compute_reward(df)
  137. rewards.append(reward)
  138. self.ucb.update(i, reward) # Update UCB for every coin every bar
  139. # Step 2: Check exits for open positions
  140. for data in self.datas:
  141. coin = data._name
  142. pos = self.getposition(data).size
  143. if pos != 0:
  144. bars_held = self.position_entry_bars.get(coin, 0) + 1
  145. self.position_entry_bars[coin] = bars_held
  146. entry_price = self.position_entry_prices[coin]
  147. current_price = data.close[0]
  148. pnl_pct = (current_price - entry_price) / entry_price
  149. # Exit conditions
  150. if bars_held < self.p.min_hold_bars:
  151. continue # Enforce min hold
  152. if bars_held >= self.p.max_hold_bars or pnl_pct <= -self.p.stop_loss_pct or pnl_pct >= self.p.take_profit_pct:
  153. self.pnl_per_coin[coin] += (current_price - entry_price) * abs(pos)
  154. self.close(data)
  155. logger.info(
  156. f"Trade Closed: Code={coin}, Exit Time={data.datetime.datetime()}, "
  157. f"Exit Price={current_price:.4f}, PnL %={pnl_pct:.4f}, "
  158. f"Quantity Long={pos:.4f}, Quantity Closed={pos:.4f}"
  159. )
  160. del self.position_entry_bars[coin]
  161. del self.position_entry_prices[coin]
  162. # Step 3: Select top-N coins via UCB scores
  163. scores = self.ucb.compute_scores()
  164. top_indices = np.argsort(scores)[-self.p.top_n:]
  165. top_coins = [coins[i] for i in top_indices]
  166. # Step 4: Enter new positions if possible
  167. portfolio_value = self.broker.getvalue()
  168. cash = self.broker.getcash()
  169. for coin in top_coins:
  170. data = self.getdatabyname(coin)
  171. if self.getposition(data).size == 0 and cash >= portfolio_value * self.p.position_size_pct:
  172. price = data.close[0]
  173. size = (portfolio_value * self.p.position_size_pct) / price
  174. self.buy(data=data, size=size)
  175. self.position_entry_bars[coin] = 0
  176. self.position_entry_prices[coin] = price
  177. logger.info(
  178. f"Trade Executed: Code={coin}, Entry Time={data.datetime.datetime()}, "
  179. f"Entry Price={price:.4f}, Quantity Long={size:.4f}, Quantity Closed=0.0000"
  180. )
  181. # --- Record holdings and portfolio value ---
  182. current_datetime = self.datas[0].datetime.datetime()
  183. holdings = {data._name: self.getposition(data).size for data in self.datas}
  184. portfolio_value = self.broker.getvalue()
  185. self.holdings_history.append({'datetime': current_datetime, **holdings})
  186. self.portfolio_value_history.append({'datetime': current_datetime, 'portfolio_value': portfolio_value})
  187. coin_values = {'datetime': current_datetime}
  188. for data in self.datas:
  189. coin = data._name
  190. size = self.getposition(data).size
  191. price = data.close[0]
  192. coin_values[coin] = size * price # Market value of position
  193. self.coin_value_history.append(coin_values)
  194. def notify_order(self, order):
  195. if order.status in [order.Completed]:
  196. pass # Can add more logging if needed
  197. # Entry point
  198. def main():
  199. exchange_params = {
  200. 'binance': {
  201. 'commission': 0.001,
  202. 'slippage': 0.0,
  203. 'timeframe': '1h',
  204. },
  205. 'coinbase': {
  206. 'commission': 0.0015,
  207. 'slippage': 0.0,
  208. 'timeframe': '1h',
  209. },
  210. 'kraken': {
  211. 'commission': 0.0026,
  212. 'slippage': 0.0,
  213. 'timeframe': '1h',
  214. },
  215. # Add more exchanges as needed
  216. }
  217. # Use parameters in your setup
  218. selected_exchange = 'binance' # Set this as needed
  219. commission = exchange_params[selected_exchange]['commission']
  220. slippage = exchange_params[selected_exchange]['slippage']
  221. timeframe = exchange_params[selected_exchange]['timeframe']
  222. # Backtest parameters
  223. end_date = datetime.now()
  224. start_date = end_date - timedelta(days=15) # Last 1 year
  225. initial_capital = 10000.0
  226. refresh = False # Set to True to force fresh data fetch
  227. logger.info(f"Starting backtest from {start_date} to {end_date} (refresh={refresh})")
  228. # Fetch/load data
  229. data_feeds = {}
  230. for coin in coins:
  231. df = fetch_historical_ohlcv(coin, start_date=start_date, end_date=end_date, refresh=refresh)
  232. if not df.empty:
  233. data_feeds[coin] = bt.feeds.PandasData(dataname=df, name=coin)
  234. if not data_feeds:
  235. logger.error("No data available; aborting")
  236. return
  237. # Set up Backtrader
  238. cerebro = bt.Cerebro()
  239. for coin, feed in data_feeds.items():
  240. cerebro.adddata(feed)
  241. cerebro.addstrategy(
  242. UCBStrategy,
  243. position_size_pct=0.15,
  244. top_n=5,
  245. min_hold_bars=6,
  246. max_hold_bars=24,
  247. stop_loss_pct=0.03,
  248. take_profit_pct=0.08,
  249. ucb_c=1.5
  250. )
  251. cerebro.broker.setcash(initial_capital)
  252. cerebro.broker.setcommission(commission=0.001) # 0.1%
  253. cerebro.broker.setcommission(commission=commission)
  254. # Optionally, set slippage
  255. cerebro.broker.set_slippage_perc(slippage)
  256. # Add analyzers
  257. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe')
  258. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
  259. cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
  260. cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
  261. # Run backtest
  262. results = cerebro.run()
  263. strat = results[0]
  264. # Convert to DataFrame
  265. holdings_df = pd.DataFrame(strat.holdings_history)
  266. portfolio_df = pd.DataFrame(strat.portfolio_value_history)
  267. coin_value_df = pd.DataFrame(strat.coin_value_history)
  268. # Save to CSV
  269. holdings_df.to_csv('holdings_by_date.csv', index=False)
  270. portfolio_df.to_csv('portfolio_value_by_date.csv', index=False)
  271. # Print metrics
  272. try:
  273. logger.info(f"Final Portfolio Value: {cerebro.broker.getvalue():.2f}")
  274. except Exception as e:
  275. logger.error(f"Error logging Final Portfolio Value: {e}")
  276. try:
  277. logger.info(f"Sharpe Ratio: {strat.analyzers.sharpe.get_analysis().get('sharperatio', 0):.2f}")
  278. except Exception as e:
  279. logger.error(f"Error logging Sharpe Ratio: {e}")
  280. try:
  281. logger.info(f"Max Drawdown: {strat.analyzers.drawdown.get_analysis().max.drawdown:.2f}%")
  282. except Exception as e:
  283. logger.error(f"Error logging Max Drawdown: {e}")
  284. try:
  285. logger.info(f"Total Return: {strat.analyzers.returns.get_analysis().rtot:.4f}")
  286. except Exception as e:
  287. logger.error(f"Error logging Total Return: {e}")
  288. try:
  289. logger.info(f"Number of Trades: {strat.analyzers.trades.get_analysis().total.total}")
  290. except Exception as e:
  291. logger.error(f"Error logging Number of Trades: {e}")
  292. # Plot trades (entries/exits) on candlestick charts for each coin
  293. cerebro.plot(style='candle', iplot=False, numfigs=1) # Generates one figure with subplots per coin
  294. # Plot net P/L per coin as a bar chart
  295. pnl_data = strat.pnl_per_coin
  296. fig, ax = plt.subplots(figsize=(12, 6))
  297. ax.bar(pnl_data.keys(), pnl_data.values(), color=['green' if v > 0 else 'red' for v in pnl_data.values()])
  298. ax.set_title('Net P/L per Coin at End of Backtest')
  299. ax.set_xlabel('Coin')
  300. ax.set_ylabel('Net P/L ($)')
  301. ax.grid(True)
  302. plt.xticks(rotation=45, ha='right')
  303. plt.tight_layout()
  304. plt.show()
  305. plt.figure(figsize=(14, 7))
  306. plt.plot(coin_value_df['datetime'], coin_value_df.drop('datetime', axis=1).sum(axis=1), label='Total Portfolio Value', color='black', linewidth=2)
  307. for coin in coins:
  308. if coin in coin_value_df.columns:
  309. plt.plot(coin_value_df['datetime'], coin_value_df[coin], label=coin, alpha=0.6, linewidth=1)
  310. plt.title('Portfolio Value Over Time (with Individual Coin Holdings)')
  311. plt.xlabel('Date')
  312. plt.ylabel('Value ($)')
  313. plt.legend(loc='upper left', fontsize='small', ncol=2)
  314. plt.grid(True)
  315. plt.tight_layout()
  316. plt.show()
  317. if __name__ == "__main__":
  318. main()