trend_detect.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
import csv
import sys

import numpy as np
  3. def read_crypto_csv(filename):
  4. """
  5. Converts a CSV file of cryptocurrency prices into a dictionary format.
  6. Args:
  7. filename (str): Path to the CSV file
  8. Returns:
  9. dict: {coin_symbol: list_of_prices}
  10. """
  11. crypto_data = {}
  12. with open(filename, 'r') as f:
  13. reader = csv.reader(f)
  14. headers = next(reader) # Read header row
  15. # Verify CSV format (first column should be 'Symbol')
  16. if headers[0] != 'Symbol':
  17. raise ValueError("CSV format incorrect. First column should be 'Symbol'")
  18. # Process each row
  19. for row in reader:
  20. symbol = row[0]
  21. prices = [float(price) for price in row[1:]] # Convert prices to floats
  22. crypto_data[symbol] = prices
  23. return crypto_data
  24. def detect_trend(data, window_size=10, slope_threshold=0.1):
  25. """
  26. Detects upward/downward trends for multiple coins using moving average slopes.
  27. Args:
  28. data (dict): Dictionary of {coin: list_of_prices}
  29. window_size (int): Days to analyze for initial trend
  30. slope_threshold (float): Slope magnitude threshold for trend classification
  31. Returns:
  32. dict: {coin: {'trend': 'upward'/'downward'/'flat', 'slope': float}}
  33. """
  34. results = {}
  35. for coin, prices in data.items():
  36. if len(prices) < window_size:
  37. raise ValueError(f"Not enough data for {coin}. Needs at least {window_size} days")
  38. # Calculate moving average
  39. window_prices = prices[:window_size]
  40. x = np.arange(window_size)
  41. y = np.array(window_prices)
  42. # Calculate slope using linear regression
  43. numerator = window_size * np.sum(x*y) - np.sum(x) * np.sum(y)
  44. denominator = window_size * np.sum(x**2) - (np.sum(x))**2
  45. slope = numerator / denominator if denominator != 0 else 0
  46. # Classify trend
  47. if slope > slope_threshold:
  48. trend = 'upward'
  49. elif slope < -slope_threshold:
  50. trend = 'downward'
  51. else:
  52. trend = 'flat'
  53. results[coin] = {'trend': trend, 'slope': slope}
  54. return results
  55. def validate_trend(data, initial_trends, window_size=10, validation_period=5, slope_threshold=0.1):
  56. """
  57. Validates trends by analyzing subsequent price movements.
  58. Args:
  59. data (dict): Dictionary of {coin: list_of_prices}
  60. initial_trends (dict): Results from detect_trend()
  61. window_size (int): Initial detection window size
  62. validation_period (int): Days after window_size to analyze
  63. slope_threshold (float): Slope threshold for validation
  64. Returns:
  65. dict: {coin: {'initial_trend': str, 'validation_trend': str, 'result': str}}
  66. """
  67. validation_results = {}
  68. for coin, prices in data.items():
  69. if len(prices) < window_size + validation_period:
  70. raise ValueError(f"Not enough data for validation for {coin}")
  71. # Get validation window prices (days 11-16 for window_size=10, validation_period=5)
  72. validation_prices = prices[window_size:window_size+validation_period]
  73. # Calculate validation trend
  74. x_val = np.arange(validation_period)
  75. y_val = np.array(validation_prices)
  76. numerator_val = validation_period * np.sum(x_val*y_val) - np.sum(x_val) * np.sum(y_val)
  77. denominator_val = validation_period * np.sum(x_val**2) - (np.sum(x_val))**2
  78. slope_val = numerator_val / denominator_val if denominator_val != 0 else 0
  79. # Classify validation trend
  80. if slope_val > slope_threshold:
  81. val_trend = 'upward'
  82. elif slope_val < -slope_threshold:
  83. val_trend = 'downward'
  84. else:
  85. val_trend = 'flat'
  86. # Determine validation result
  87. initial = initial_trends[coin]['trend']
  88. if initial == 'flat':
  89. if val_trend != 'flat':
  90. result = f'New {val_trend} trend emerged'
  91. else:
  92. result = 'Continued flat'
  93. else:
  94. if val_trend == initial:
  95. result = 'Trend continued'
  96. elif val_trend == 'flat':
  97. result = 'Trend weakened to flat'
  98. else:
  99. result = 'Trend reversed'
  100. validation_results[coin] = {
  101. 'initial_trend': initial,
  102. 'validation_trend': val_trend,
  103. 'result': result,
  104. 'validation_slope': slope_val
  105. }
  106. return validation_results
  107. # Usage example
  108. if __name__ == '__main__':
  109. # Convert CSV to crypto_data structure
  110. crypto_data = read_crypto_csv('crypto_prices.csv')
  111. # Detect initial trends
  112. initial_results = detect_trend(crypto_data, window_size=10)
  113. # Validate trends using next 5 days
  114. validation_results = validate_trend(crypto_data, initial_results, window_size=10, validation_period=5)
  115. # Print formatted results
  116. print("Coin\t\tInitial\t\tValidation\tResult")
  117. print("-----------------------------------------------------------")
  118. for coin, result in validation_results.items():
  119. print(f"{coin}\t{result['initial_trend']:8}\t{result['validation_trend']:8}\t{result['result']}")
  120. # # Example: Print first 3 coins and their first 5 prices
  121. # for i, (coin, prices) in enumerate(crypto_data.items()):
  122. # if i >= 3:
  123. # break
  124. # print(f"{coin}: {prices[:5]}...")