ucb_backtest.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import logging
  2. import math
  3. import os # For file/path operations
  4. from datetime import datetime, timedelta
  5. import ccxt
  6. import matplotlib.pyplot as plt
  7. import matplotlib
  8. import numpy as np
  9. import pandas as pd
  10. matplotlib.use('TkAgg') # Or 'Qt5Agg' if you have Qt installed
  11. # Configure logging (outputs to console)
  12. logging.basicConfig(
  13. level=logging.INFO,
  14. format='%(asctime)s - %(levelname)s - %(message)s'
  15. )
  16. logger = logging.getLogger(__name__)
  17. # Top 15 coins (as before)
  18. coins = ['BTC/USDT', 'ETH/USDT', 'BNB/USDT', 'SOL/USDT', 'XRP/USDT', 'ADA/USDT', 'DOGE/USDT',
  19. 'AVAX/USDT', 'SHIB/USDT', 'DOT/USDT', 'LINK/USDT', 'TRX/USDT', 'UNI/USDT', 'LTC/USDT']
  20. exchange = ccxt.binance({'enableRateLimit': True})
  21. # UCB Class (from previous implementation)
  22. class UCB:
  23. def __init__(self, num_arms, c=2.0):
  24. self.num_arms = num_arms
  25. self.counts = np.zeros(num_arms)
  26. self.mean_rewards = np.zeros(num_arms)
  27. self.total_pulls = 0
  28. self.c = c
  29. def select_arm(self):
  30. ucb_scores = np.zeros(self.num_arms)
  31. for i in range(self.num_arms):
  32. if self.counts[i] == 0:
  33. return i
  34. ucb_scores[i] = self.mean_rewards[i] + self.c * math.sqrt(math.log(self.total_pulls + 1) / self.counts[i])
  35. return np.argmax(ucb_scores)
  36. def update(self, arm, reward):
  37. self.counts[arm] += 1
  38. self.total_pulls += 1
  39. self.mean_rewards[arm] = (self.mean_rewards[arm] * (self.counts[arm] - 1) + reward) / self.counts[arm]
  40. # Feature Computation Functions (from previous implementation)
  41. def compute_atr(df, period=14):
  42. high_low = df['high'] - df['low']
  43. high_close = np.abs(df['high'] - df['close'].shift())
  44. low_close = np.abs(df['low'] - df['close'].shift())
  45. tr = np.maximum(high_low, high_close, low_close)
  46. atr = tr.rolling(period).mean()
  47. return atr
  48. def compute_ema(df, short=12, long=26):
  49. df['ema_short'] = df['close'].ewm(span=short, adjust=False).mean()
  50. df['ema_long'] = df['close'].ewm(span=long, adjust=False).mean()
  51. df['trend'] = np.where(df['ema_short'] > df['ema_long'], 1, -1)
  52. def compute_rewards(df):
  53. df['return'] = (df['close'] - df['open']) / df['open']
  54. df['atr'] = compute_atr(df)
  55. compute_ema(df)
  56. df['reward'] = df['return'] * df['trend'] / df['atr'].replace(0, np.nan)
  57. return df.dropna()
  58. # Fetch historical OHLCV data for a symbol (with caching)
  59. def fetch_historical_ohlcv(symbol, timeframe='1h', start_date=None, end_date=None, limit=1000, refresh=False):
  60. # Create dat folder if it doesn't exist
  61. os.makedirs('dat', exist_ok=True)
  62. # Generate safe filename with program prefix
  63. program_prefix = 'ucb_backtest'
  64. symbol_safe = symbol.replace('/', '-')
  65. start_str = start_date.strftime('%Y%m%d') if start_date else 'none'
  66. end_str = end_date.strftime('%Y%m%d') if end_date else 'none'
  67. filename = f"dat/{program_prefix}_{symbol_safe}_{timeframe}_{start_str}_{end_str}.csv"
  68. if not refresh and os.path.exists(filename):
  69. try:
  70. df = pd.read_csv(filename, index_col='timestamp', parse_dates=True)
  71. logger.info(f"Loaded cached data for {symbol} from {filename}")
  72. return df
  73. except Exception as e:
  74. logger.warning(f"Error loading cache for {symbol}: {str(e)}; fetching fresh data")
  75. # Fetch fresh data
  76. try:
  77. since = int(start_date.timestamp() * 1000) if start_date else None
  78. ohlcv = []
  79. while True:
  80. data = exchange.fetch_ohlcv(symbol, timeframe, since=since, limit=limit)
  81. if not data:
  82. break
  83. ohlcv.extend(data)
  84. since = data[-1][0] + 1 # Next batch starts after last timestamp
  85. if end_date and data[-1][0] >= int(end_date.timestamp() * 1000):
  86. break
  87. df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
  88. df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
  89. df.set_index('timestamp', inplace=True)
  90. if end_date:
  91. df = df[df.index <= end_date]
  92. logger.info(f"Fetched {len(df)} historical candles for {symbol}")
  93. # Save to cache
  94. df.to_csv(filename)
  95. logger.info(f"Saved data for {symbol} to {filename}")
  96. return df
  97. except Exception as e:
  98. logger.error(f"Error fetching historical data for {symbol}: {str(e)}")
  99. return pd.DataFrame()
  100. # Backtest function (added refresh param)
  101. def backtest_ucb(start_date, end_date, initial_capital=10000.0, refresh=False):
  102. # Fetch historical data for all coins (with caching)
  103. historical_data = {}
  104. for coin in coins:
  105. df = fetch_historical_ohlcv(coin, start_date=start_date, end_date=end_date, refresh=refresh)
  106. if not df.empty:
  107. historical_data[coin] = compute_rewards(df)
  108. if not historical_data:
  109. logger.error("No historical data available; aborting backtest")
  110. return None
  111. # Find common timestamps across all coins (align data)
  112. all_timestamps = sorted(set.intersection(*(set(df.index) for df in historical_data.values())))
  113. logger.info(f"Backtesting over {len(all_timestamps)} aligned periods")
  114. # Initialize UCB and portfolio tracking
  115. ucb = UCB(len(coins))
  116. portfolio_values = [initial_capital]
  117. selected_coins = []
  118. period_returns = []
  119. current_capital = initial_capital
  120. for i in range(1, len(all_timestamps)): # Start from second period to have prior data
  121. current_time = all_timestamps[i]
  122. prev_time = all_timestamps[i-1]
  123. # Get data slices up to current time for reward computation
  124. current_data = {}
  125. for coin in coins:
  126. if coin in historical_data:
  127. df_slice = historical_data[coin].loc[:current_time]
  128. if not df_slice.empty:
  129. current_data[coin] = df_slice
  130. if not current_data:
  131. continue
  132. # Select arm (coin)
  133. arm = ucb.select_arm()
  134. coin = coins[arm]
  135. if coin not in current_data or current_data[coin].empty:
  136. logger.warning(f"No data for selected coin {coin} at {current_time}; skipping")
  137. continue
  138. # Simulate trade: Get return from prev to current close
  139. prev_close = historical_data[coin].loc[prev_time, 'close'] if prev_time in historical_data[coin].index else None
  140. current_close = historical_data[coin].loc[current_time, 'close']
  141. if prev_close is None:
  142. continue
  143. period_return = (current_close - prev_close) / prev_close
  144. reward = current_data[coin].loc[current_time, 'reward'] # Use computed reward for UCB update
  145. # Calculate quantity (simulated: full allocation)
  146. quantity = current_capital / prev_close if prev_close != 0 else 0
  147. # Log trade details
  148. logger.info(f"Trade Executed: Code={coin}, Entry Time={prev_time}, Entry Price={prev_close:.4f}, Quantity={quantity:.4f}")
  149. logger.info(f"Trade Closed: Exit Time={current_time}, Exit Price={current_close:.4f}, Realized Return={period_return:.4f}, Reward={reward:.4f}")
  150. # Update portfolio
  151. current_capital *= (1 + period_return)
  152. portfolio_values.append(current_capital)
  153. period_returns.append(period_return)
  154. selected_coins.append(coin)
  155. # Update UCB with realized reward
  156. ucb.update(arm, reward)
  157. logger.debug(f"Period {current_time}: Selected {coin}, Return: {period_return:.4f}, Reward: {reward:.4f}, Capital: {current_capital:.2f}")
  158. # Compute performance metrics
  159. if not period_returns:
  160. logger.error("No trades executed; aborting metrics")
  161. return None
  162. total_return = (current_capital - initial_capital) / initial_capital
  163. num_periods = len(period_returns)
  164. days = (end_date - start_date).days
  165. annualized_return = (1 + total_return) ** (365 / days) - 1 if days > 0 else 0
  166. sharpe_ratio = np.mean(period_returns) / np.std(period_returns) * np.sqrt(8760) if np.std(period_returns) != 0 else 0 # Annualized, 8760 hours/year
  167. max_drawdown = np.min(np.cumprod(1 + np.array(period_returns)) / np.maximum.accumulate(np.cumprod(1 + np.array(period_returns)))) - 1
  168. results = {
  169. 'total_return': total_return,
  170. 'annualized_return': annualized_return,
  171. 'sharpe_ratio': sharpe_ratio,
  172. 'max_drawdown': max_drawdown,
  173. 'final_capital': current_capital,
  174. 'num_trades': num_periods
  175. }
  176. logger.info(f"Backtest Results: {results}")
  177. # Plot equity curve
  178. plt.figure(figsize=(10, 6))
  179. plt.plot(all_timestamps[:len(portfolio_values)], portfolio_values, label='Portfolio Value')
  180. plt.title('UCB Strategy Equity Curve')
  181. plt.xlabel('Date')
  182. plt.ylabel('Portfolio Value')
  183. plt.legend()
  184. plt.grid(True)
  185. plt.show()
  186. return results
  187. # Entry point
  188. def main():
  189. # Backtest parameters (adjust as needed)
  190. end_date = datetime.now()
  191. start_date = end_date - timedelta(days=180) # Last 1 year
  192. initial_capital = 10000.0
  193. refresh = True # Set to True to force fresh data fetch
  194. logger.info(f"Starting backtest from {start_date} to {end_date} (refresh={refresh})")
  195. backtest_ucb(start_date, end_date, initial_capital, refresh)
  196. if __name__ == "__main__":
  197. main()