瀏覽代碼

long holiday portfolio changes

- include web sockets
- portfolio monitor calculations
- split portfolio_monitor.py into multiple files
- temporarily set up new Portfolio class in temp_pm.py
- coding abstract table model classes
laxaurus 9 年之前
父節點
當前提交
333c7ea80a

+ 4 - 1
src/finopt/instrument.py

@@ -5,7 +5,10 @@ from misc2.helpers import ContractHelper, dict2str
 class Symbol():
 class Symbol():
     key = None
     key = None
     
     
-
+    LAST = 4
+    BID  = 1
+    ASK  = 2
+    
     
     
     
     
     def __init__(self, contract):
     def __init__(self, contract):

+ 8 - 2
src/misc2/helpers.py

@@ -218,12 +218,18 @@ class ContractHelper(BaseHelper):
 #change strike format to 2 dp     
 #change strike format to 2 dp     
 
 
 # amend 2017/04/25
 # amend 2017/04/25
-        if contract.m_exchange == '' or contract.m_exchange == None:
+        if contract.m_exchange == None:
             try:
             try:
                 contract.m_exchange = ContractHelper.map_rules['exchange'][contract.m_symbol]
                 contract.m_exchange = ContractHelper.map_rules['exchange'][contract.m_symbol]
             except:
             except:
-                pass
+                contract.m_exchange = '' 
    
    
+        if contract.m_right == None:
+            contract.m_right = ''
+            
+        if contract.m_expiry == None:
+            contract.m_expiry = ''
+            
         s = '%s-%s-%.2f-%s-%s-%s-%s-%d' % (contract.m_symbol,
         s = '%s-%s-%.2f-%s-%s-%s-%s-%d' % (contract.m_symbol,
                                                            contract.m_expiry,
                                                            contract.m_expiry,
                                                            float(contract.m_strike),
                                                            float(contract.m_strike),

+ 2 - 2
src/rethink/analytics_engine.py

@@ -34,7 +34,7 @@ class AnalyticsEngine(AbstractGatewayListener):
         
         
     
     
     def test_oc(self, oc2):
     def test_oc(self, oc2):
-        expiry = '20170427'
+        expiry = '20170529'
         contractTuple = ('HSI', 'FUT', 'HKFE', 'HKD', expiry, 0, '')
         contractTuple = ('HSI', 'FUT', 'HKFE', 'HKD', expiry, 0, '')
         contract = ContractHelper.makeContract(contractTuple)  
         contract = ContractHelper.makeContract(contractTuple)  
         
         
@@ -59,7 +59,7 @@ class AnalyticsEngine(AbstractGatewayListener):
         
         
     
     
     def test_oc3(self, oc3):
     def test_oc3(self, oc3):
-        expiry = '20170529'
+        expiry = '20170629'
         contractTuple = ('HSI', 'FUT', 'HKFE', 'HKD', expiry, 0, '')
         contractTuple = ('HSI', 'FUT', 'HKFE', 'HKD', expiry, 0, '')
         contract = ContractHelper.makeContract(contractTuple)  
         contract = ContractHelper.makeContract(contractTuple)  
          
          

+ 277 - 2
src/rethink/portfolio_item.py

@@ -11,6 +11,7 @@ from finopt.instrument import Symbol, Option
 from rethink.option_chain import OptionsChain
 from rethink.option_chain import OptionsChain
 from rethink.tick_datastore import TickDataStore
 from rethink.tick_datastore import TickDataStore
 from numpy import average
 from numpy import average
+from rethink.table_model import AbstractTableModel
 
 
 class PortfolioRules():
 class PortfolioRules():
     rule_map = {
     rule_map = {
@@ -34,6 +35,7 @@ class PortfolioItem():
     AVERAGE_COST = 7002
     AVERAGE_COST = 7002
     POSITION_DELTA = 7003
     POSITION_DELTA = 7003
     POSITION_THETA = 7004
     POSITION_THETA = 7004
+    POSITION_GAMMA = 7009
     UNREAL_PL = 7005
     UNREAL_PL = 7005
     PERCENT_GAIN_LOSS = 7006
     PERCENT_GAIN_LOSS = 7006
     AVERAGE_PRICE = 7007
     AVERAGE_PRICE = 7007
@@ -79,9 +81,21 @@ class PortfolioItem():
     def get_port_fields(self):
     def get_port_fields(self):
         return self.port_fields
         return self.port_fields
     
     
+    def get_contract_key(self):
+        return self.contract_key
+    
+    def get_right(self):
+        return self.instrument.get_contract().m_right
+    
     def get_symbol_id(self):
     def get_symbol_id(self):
         return self.instrument.get_contract().m_symbol
         return self.instrument.get_contract().m_symbol
     
     
+    def get_expiry(self):
+        return self.instrument.get_contract().m_expiry
+    
+    def get_strike(self):
+        return self.instrument.get_contract().m_strike
+    
     def get_quantity(self):
     def get_quantity(self):
         return self.port_fields[PortfolioItem.POSITION]
         return self.port_fields[PortfolioItem.POSITION]
     
     
@@ -123,7 +137,7 @@ class PortfolioItem():
                 
                 
                 pos_delta = self.get_quantity() * self.instrument.get_tick_value(Option.DELTA) * multiplier                                
                 pos_delta = self.get_quantity() * self.instrument.get_tick_value(Option.DELTA) * multiplier                                
                 pos_theta = self.get_quantity() * self.instrument.get_tick_value(Option.THETA) * multiplier
                 pos_theta = self.get_quantity() * self.instrument.get_tick_value(Option.THETA) * multiplier
-                               
+                pos_gamma = self.get_quantity() * self.instrument.get_tick_value(Option.GAMMA) * multiplier                               
 
 
                 #(spot premium * multiplier - avgcost) * pos)
                 #(spot premium * multiplier - avgcost) * pos)
                 unreal_pl = (spot_px * multiplier - self.get_average_cost()) * self.get_quantity()
                 unreal_pl = (spot_px * multiplier - self.get_average_cost()) * self.get_quantity()
@@ -132,23 +146,27 @@ class PortfolioItem():
                                         if self.get_quantity() < 0 else \
                                         if self.get_quantity() < 0 else \
                                         (spot_px - self.get_average_cost() / multiplier) / (self.get_average_cost() / multiplier) * 100 
                                         (spot_px - self.get_average_cost() / multiplier) / (self.get_average_cost() / multiplier) * 100 
                                     
                                     
-                                     
+                average_px = self.get_average_cost() / multiplier                    
                             
                             
                             
                             
             else:
             else:
                 pos_delta = self.get_quantity() * 1.0 * \
                 pos_delta = self.get_quantity() * 1.0 * \
                                PortfolioRules.rule_map['option_structure'][self.get_symbol_id()]['multiplier'] 
                                PortfolioRules.rule_map['option_structure'][self.get_symbol_id()]['multiplier'] 
                 pos_theta = 0
                 pos_theta = 0
+                pos_gamma = 0
                 # (S - X) * pos * multiplier
                 # (S - X) * pos * multiplier
                 unreal_pl = (self.instrument.get_tick_value(4) - self.get_average_cost() ) * self.get_quantity() * \
                 unreal_pl = (self.instrument.get_tick_value(4) - self.get_average_cost() ) * self.get_quantity() * \
                                PortfolioRules.rule_map['option_structure'][self.get_symbol_id()]['multiplier']
                                PortfolioRules.rule_map['option_structure'][self.get_symbol_id()]['multiplier']
                                
                                
                 sign = abs(self.get_quantity()) / self.get_quantity()                                
                 sign = abs(self.get_quantity()) / self.get_quantity()                                
                 percent_gain_loss = sign * (spot_px - self.get_average_cost() / multiplier) / (self.get_average_cost() / multiplier) * 100
                 percent_gain_loss = sign * (spot_px - self.get_average_cost() / multiplier) / (self.get_average_cost() / multiplier) * 100
+                average_px = self.get_average_cost() / multiplier
                         
                         
             self.set_port_field(PortfolioItem.POSITION_DELTA, pos_delta)
             self.set_port_field(PortfolioItem.POSITION_DELTA, pos_delta)
             self.set_port_field(PortfolioItem.POSITION_THETA, pos_theta)
             self.set_port_field(PortfolioItem.POSITION_THETA, pos_theta)
+            self.set_port_field(PortfolioItem.POSITION_GAMMA, pos_gamma)
             self.set_port_field(PortfolioItem.UNREAL_PL, unreal_pl)
             self.set_port_field(PortfolioItem.UNREAL_PL, unreal_pl)
+            self.set_port_field(PortfolioItem.AVERAGE_PRICE, average_px)
             self.set_port_field(PortfolioItem.PERCENT_GAIN_LOSS, percent_gain_loss)
             self.set_port_field(PortfolioItem.PERCENT_GAIN_LOSS, percent_gain_loss)
             
             
         except Exception, err:
         except Exception, err:
@@ -164,6 +182,263 @@ class PortfolioItem():
         if extra_info:
         if extra_info:
             self.set_port_field(PortfolioItem.MARKET_VALUE, extra_info['market_value'])
             self.set_port_field(PortfolioItem.MARKET_VALUE, extra_info['market_value'])
         
         
+        
+        
     def dump(self):
     def dump(self):
         s= ", ".join('[%s:%8.2f]' % (k, v) for k,v in self.port_fields.iteritems())
         s= ", ".join('[%s:%8.2f]' % (k, v) for k,v in self.port_fields.iteritems())
         return 'PortfolioItem contents: %s %s %s' % (self.contract_key, self.account_id, s)
         return 'PortfolioItem contents: %s %s %s' % (self.contract_key, self.account_id, s)
+    
+    
+class Portfolio(AbstractTableModel):
+    '''
+        portfolio : 
+             {
+                'port_items': {<contract_key>, PortItem}, 
+                'opt_chains': {<oc_id>: option_chain}, 
+                'g_table':{'rows':{...} , 'cols':{...}, 
+                           'header':{...},
+                           'row_index': <curr_index>,
+                           'ckey_to_row_index':{<contract_key>: {'row_id':<row_id>, 'dirty': <true/false>, 'count':0}, 
+                           'row_to_ckey_index':{<row_id>: <contract_key>}
+                                            
+             }   
+                
+    '''    
+    def __init__(self, account):
+        self.account = account
+        self.create_empty_portfolio(account)
+        
+        
+    def is_contract_in_portfolio(self, account, contract_key):
+        return self.get_portfolio_port_item(account, contract_key)
+            
+    def get_portfolio_port_item(self, account, contract_key):
+        try:
+            return self.port['port_items'][contract_key]
+        except KeyError:
+            return None
+        
+    def create_empty_portfolio(self, account):
+        self.port = {}
+        self.port['port_items']=  {}
+        self.port['opt_chains']=  {}
+        
+        
+        self.port['g_table']=  {'row_index': 0}
+        self.init_table()
+        return self.port        
+
+    
+    def set_portfolio_port_item(self, contract_key, port_item):
+        self.port['port_items'][contract_key] = port_item
+        
+        '''
+            update the gtable contract_key to row number index
+        '''
+        self.update_ckey_row_index(contract_key)
+        
+                
+    def is_oc_in_portfolio(self, oc_id):
+        try:
+            return self.port['opt_chains'][oc_id]
+        except KeyError:
+            return None
+
+    def get_option_chain(self, oc_id):
+        return self.is_oc_in_portfolio(oc_id)
+        
+    def set_option_chain(self, oc_id, oc):
+        self.port['opt_chains'][oc_id] = oc
+
+    def get_option_chains(self):
+        return self.port['opt_chains']
+
+    def calculate_item_pl(self, contract_key):
+        self.port['port_items'][contract_key].calculate_pl(contract_key)
+        
+        
+
+        
+           
+    def g_datatable_json(self):
+        dtj = {'cols':[], 'rows':[], 'ckey_to_row_index':{}}
+        header = [('symbol', 'Symbol', 'string'), ('right', 'Right', 'string'), ('avgcost', 'Avg Cost', 'number'), ('market_value', 'Market Value', 'number'), 
+                  ('avgpx', 'Avg Price', 'number'), ('spotpx', 'Spot Price', 'number'), ('pos', 'Quantity', 'number'), 
+                  ('delta', 'Delta', 'number'), ('theta', 'Theta', 'number'), ('gamma', 'Gamma', 'number'), 
+                  ('pos_delta', 'P. Delta', 'number'), ('pos_theta', 'P. Theta', 'number'), ('pos_gamma', 'P. Gamma', 'number'), 
+                  ('unreal_pl', 'Unreal P/L', 'number'), ('percent_gain_loss', '% gain/loss', 'number')  
+                  ]  
+        # header fields      
+        map(lambda hf: dtj['cols'].append({'id': hf[0], 'label': hf[1], 'type': hf[2]}), header)
+        
+        
+        def get_spot_px(x):
+            px = float('nan')
+            if x.get_quantity() > 0:
+                px= x.get_instrument().get_tick_value(Symbol.BID)
+            elif x.get_quantity() < 0:
+                px= x.get_instrument().get_tick_value(Symbol.ASK)
+            if px == -1:
+                return x.get_instrument().get_tick_value(Symbol.LAST)
+        
+        # table rows
+        def row_fields(x):
+            
+            rf = [{'v': '%s-%s-%s' % (x[1].get_symbol_id(), x[1].get_expiry(), x[1].get_strike())}, 
+                 {'v': x[1].get_right()},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_COST)},
+                 {'v': x[1].get_port_field(PortfolioItem.MARKET_VALUE)},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_PRICE)},
+                 {'v': get_spot_px(x[1])},
+                 {'v': x[1].get_quantity()},
+                 {'v': x[1].get_instrument().get_tick_value(Option.DELTA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.THETA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_DELTA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_THETA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.UNREAL_PL)},
+                 {'v': x[1].get_port_field(PortfolioItem.PERCENT_GAIN_LOSS)}]
+                 
+             
+            return rf 
+        
+        
+        p_items = sorted([x for x in self.port['port_items'].iteritems()])
+        p1_items = filter(lambda x: x[1].get_symbol_id() in self.kwargs['interested_position_types']['symbol'], p_items)
+        p2_items = filter(lambda x: x[1].get_instrument_type() in self.kwargs['interested_position_types']['instrument_type'], p1_items)
+        map(lambda p: dtj['rows'].append({'c': row_fields(p)}), p2_items)
+        
+        
+        return json.dumps(dtj) #, indent=4)            
+        
+
+
+
+    def dump_portfolio(self):
+        #<account_id>: {'port_items': {<contract_key>, instrument}, 'opt_chains': {<oc_id>: option_chain}}
+        
+        def print_port_items(x):
+            return '[%s]: %s %s' % (x[0],  ', '.join('%s: %s' % (k,str(v)) for k, v in x[1].get_port_fields().iteritems()),
+                                           ', '.join('%s: %s' % (k,str(v)) for k, v in x[1].get_instrument().get_tick_values().iteritems()))
+        
+        p_items = map(print_port_items, [x for x in self.port['port_items'].iteritems()])
+        logging.info('PortfolioMonitor:dump_portfolio %s' % ('\n'.join(p_items)))
+        return '\n'.join(p_items)
+    
+    
+    
+    
+    
+    '''
+        implement AbstractTableModel methods and other routines
+    '''
+    def init_table(self):
+        self.port['g_table']['header'] = [('symbol', 'Symbol', 'string'), ('right', 'Right', 'string'), ('avgcost', 'Avg Cost', 'number'), ('market_value', 'Market Value', 'number'), 
+                  ('avgpx', 'Avg Price', 'number'), ('spotpx', 'Spot Price', 'number'), ('pos', 'Quantity', 'number'), 
+                  ('delta', 'Delta', 'number'), ('theta', 'Theta', 'number'), ('gamma', 'Gamma', 'number'), 
+                  ('pos_delta', 'P. Delta', 'number'), ('pos_theta', 'P. Theta', 'number'), ('pos_gamma', 'P. Gamma', 'number'), 
+                  ('unreal_pl', 'Unreal P/L', 'number'), ('percent_gain_loss', '% gain/loss', 'number')  
+                  ]  
+    def update_ckey_row_index(self, contract_key):
+        row_id = self.port['g_table']['row_index']
+        self.port['g_table']['ckey_to_row_index'][contract_key]['row_id'] = row_id
+        self.port['g_table']['row_to_ckey_index'][row_id] = contract_key
+        self.port['g_table']['row_index'] += 1
+  
+    def ckey_to_row(self, contract_key):
+        return self.port['g_table']['ckey_to_row_index'][contract_key]['row_id']
+  
+    def get_column_count(self):
+        return len(self.port['g_table']['header'])
+    
+    def get_row_count(self):
+        p_items = [x for x in self.port['port_items'].iteritems()]
+        p1_items = filter(lambda x: x[1].get_symbol_id() in self.kwargs['interested_position_types']['symbol'], p_items)
+        p2_items = filter(lambda x: x[1].get_instrument_type() in self.kwargs['interested_position_types']['instrument_type'], p1_items)
+        return len(p2_items)
+
+    def get_column_name(self, col):
+        return self.port['g_table']['header'][col][1]
+
+    def get_value_at(self, row, col):
+        ckey = self.port['g_table']['row_to_ckey_index'][row]
+        p_item = self.port['port_items'][ckey]
+    
+    def get_values_at(self, row):
+        ckey = self.port['g_table']['row_to_ckey_index'][row]
+        p_item = self.port['port_items'][ckey]
+        return self.port_item_to_row_fields(p_item)
+    
+    def port_item_to_row_fields(self, x):
+        rf = [{'v': '%s-%s-%s' % (x[1].get_symbol_id(), x[1].get_expiry(), x[1].get_strike())}, 
+             {'v': x[1].get_right()},
+             {'v': x[1].get_port_field(PortfolioItem.AVERAGE_COST)},
+             {'v': x[1].get_port_field(PortfolioItem.MARKET_VALUE)},
+             {'v': x[1].get_port_field(PortfolioItem.AVERAGE_PRICE)},
+             {'v': get_spot_px(x[1])},
+             {'v': x[1].get_quantity()},
+             {'v': x[1].get_instrument().get_tick_value(Option.DELTA)},
+             {'v': x[1].get_instrument().get_tick_value(Option.THETA)},
+             {'v': x[1].get_instrument().get_tick_value(Option.GAMMA)},
+             {'v': x[1].get_port_field(PortfolioItem.POSITION_DELTA)},
+             {'v': x[1].get_port_field(PortfolioItem.POSITION_THETA)},
+             {'v': x[1].get_port_field(PortfolioItem.POSITION_GAMMA)},
+             {'v': x[1].get_port_field(PortfolioItem.UNREAL_PL)},
+             {'v': x[1].get_port_field(PortfolioItem.PERCENT_GAIN_LOSS)}]
+        return rf     
+    
+    
+    def set_value_at(self, row, col, value):
+        pass
+    
+
+    def get_column_id(self, col):
+        return self.port['g_table']['header'][col][0]
+    
+    def get_JSON(self):
+        dtj = {'cols':[], 'rows':[], 'ckey_to_row_index':{}}
+        # header fields      
+        map(lambda hf: dtj['cols'].append({'id': hf[0], 'label': hf[1], 'type': hf[2]}), self.port['g_table']['header'])
+        
+        
+        def get_spot_px(x):
+            px = float('nan')
+            if x.get_quantity() > 0:
+                px= x.get_instrument().get_tick_value(Symbol.BID)
+            elif x.get_quantity() < 0:
+                px= x.get_instrument().get_tick_value(Symbol.ASK)
+            if px == -1:
+                return x.get_instrument().get_tick_value(Symbol.LAST)
+        
+        # table rows
+        def row_fields(x):
+            
+            rf = [{'v': '%s-%s-%s' % (x[1].get_symbol_id(), x[1].get_expiry(), x[1].get_strike())}, 
+                 {'v': x[1].get_right()},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_COST)},
+                 {'v': x[1].get_port_field(PortfolioItem.MARKET_VALUE)},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_PRICE)},
+                 {'v': get_spot_px(x[1])},
+                 {'v': x[1].get_quantity()},
+                 {'v': x[1].get_instrument().get_tick_value(Option.DELTA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.THETA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_DELTA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_THETA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.UNREAL_PL)},
+                 {'v': x[1].get_port_field(PortfolioItem.PERCENT_GAIN_LOSS)}]
+                 
+             
+            return rf 
+        
+        
+        p_items = sorted([x for x in self.port['port_items'].iteritems()])
+        p1_items = filter(lambda x: x[1].get_symbol_id() in self.kwargs['interested_position_types']['symbol'], p_items)
+        p2_items = filter(lambda x: x[1].get_instrument_type() in self.kwargs['interested_position_types']['instrument_type'], p1_items)
+        map(lambda p: dtj['rows'].append({'c': row_fields(p)}), p2_items)
+        
+        
+        return json.dumps(dtj) #, indent=4)            
+        

二進制
src/rethink/portfolio_item.pyc


+ 96 - 13
src/rethink/portfolio_monitor.py

@@ -25,7 +25,8 @@ class PortfolioMonitor(AbstractGatewayListener):
     '''
     '''
         portfolios : 
         portfolios : 
              {
              {
-                <account_id>: {'port_items': {<contract_key>, PortItem}, 'opt_chains': {<oc_id>: option_chain}}
+                <account_id>: {'port_items': {<contract_key>, PortItem}, 'opt_chains': {<oc_id>: option_chain}, 
+                                'g_table':{'rows':{...} , 'cols':{...}, 'ckey_to_row_index':{'row_id':<row_id>, 'dirty': <true/false>, 'count':0}
              }   
              }   
                 
                 
     '''
     '''
@@ -42,7 +43,7 @@ class PortfolioMonitor(AbstractGatewayListener):
         
         
         self.portfolios = {}
         self.portfolios = {}
         
         
-    
+        
     
     
     def start_engine(self):
     def start_engine(self):
         self.twsc.start_manager()
         self.twsc.start_manager()
@@ -52,9 +53,10 @@ class PortfolioMonitor(AbstractGatewayListener):
             logging.info('PortfolioMonitor:main_loop ***** accepting console input...')
             logging.info('PortfolioMonitor:main_loop ***** accepting console input...')
             menu = {}
             menu = {}
             menu['1']="Request position" 
             menu['1']="Request position" 
-            menu['2']="Portfolio dump"
+            menu['2']="Portfolio dump dtj"
             menu['3']="TDS dump"
             menu['3']="TDS dump"
             menu['4']="Request account updates"
             menu['4']="Request account updates"
+            menu['5']="Table chart JSON"
             menu['9']="Exit"
             menu['9']="Exit"
             while True: 
             while True: 
                 choices=menu.keys()
                 choices=menu.keys()
@@ -67,13 +69,16 @@ class PortfolioMonitor(AbstractGatewayListener):
                     self.twsc.reqPositions()
                     self.twsc.reqPositions()
                 elif selection == '2': 
                 elif selection == '2': 
                     for acct in self.portfolios.keys():
                     for acct in self.portfolios.keys():
-                        print self.dump_portfolio(acct)
+                        #print self.dump_portfolio(acct)
+                        print self.portfolios[acct]['g_table']
                 elif selection == '3': 
                 elif selection == '3': 
-                    self.tds.dump()
+                    print self.tds.dump()
                 elif selection == '4': 
                 elif selection == '4': 
                     for acct in self.portfolios.keys():
                     for acct in self.portfolios.keys():
                         self.twsc.reqAccountUpdates(True, acct)
                         self.twsc.reqAccountUpdates(True, acct)
-                    
+                elif selection == '5':
+                    for acct in self.portfolios.keys():
+                        print self.g_datatable_json(acct)
                 elif selection == '9': 
                 elif selection == '9': 
                     self.twsc.gw_message_handler.set_stop()
                     self.twsc.gw_message_handler.set_stop()
                     break
                     break
@@ -88,15 +93,15 @@ class PortfolioMonitor(AbstractGatewayListener):
             logging.info('PortfolioMonitor: Service shut down complete...')               
             logging.info('PortfolioMonitor: Service shut down complete...')               
     
     
     def is_contract_in_portfolio(self, account, contract_key):
     def is_contract_in_portfolio(self, account, contract_key):
-        return self.get_portfolio_port_items(account, contract_key)
+        return self.get_portfolio_port_item(account, contract_key)
             
             
-    def get_portfolio_port_items(self, account, contract_key):
+    def get_portfolio_port_item(self, account, contract_key):
         try:
         try:
             return self.portfolios[account]['port_items'][contract_key]
             return self.portfolios[account]['port_items'][contract_key]
         except KeyError:
         except KeyError:
             return None
             return None
     
     
-    def set_portfolio_port_items(self, account, contract_key, port_item):
+    def set_portfolio_port_item(self, account, contract_key, port_item):
         self.portfolios[account]['port_items'][contract_key] = port_item
         self.portfolios[account]['port_items'][contract_key] = port_item
         
         
         
         
@@ -104,6 +109,7 @@ class PortfolioMonitor(AbstractGatewayListener):
         port = self.portfolios[account] = {}
         port = self.portfolios[account] = {}
         self.portfolios[account]['port_items']=  {}
         self.portfolios[account]['port_items']=  {}
         self.portfolios[account]['opt_chains']=  {}
         self.portfolios[account]['opt_chains']=  {}
+        self.portfolios[account]['g_table']=  {'ckey_to_row_index': {'count':0}}
         return port
         return port
                 
                 
     def get_portfolio(self, account):
     def get_portfolio(self, account):
@@ -168,7 +174,10 @@ class PortfolioMonitor(AbstractGatewayListener):
             
             
         return oc
         return oc
     
     
-    
+    def mark_gtable_row_dirty(self, account, contract_key, dirty=True):
+        self.portfolios[account]['g_table'][contract_key]['dirty'] = dirty
+        self.portfolios[account]['g_table'][contract_key]['count'] += 1
+        return self.portfolios[account]['g_table'][contract_key]['row_id']
     
     
     def process_position(self, account, contract_key, position, average_cost, extra_info=None):
     def process_position(self, account, contract_key, position, average_cost, extra_info=None):
         
         
@@ -182,6 +191,9 @@ class PortfolioMonitor(AbstractGatewayListener):
             # update the values and recalculate p/l
             # update the values and recalculate p/l
             port_item.update_position(position, average_cost, extra_info)
             port_item.update_position(position, average_cost, extra_info)
             port_item.calculate_pl(contract_key)
             port_item.calculate_pl(contract_key)
+            
+            # update the affected row in gtable as changed 
+            self.mark_gtable_row_dirty(True)
         # new position 
         # new position 
         else:
         else:
             port_item = PortfolioItem(account, contract_key, position, average_cost)
             port_item = PortfolioItem(account, contract_key, position, average_cost)
@@ -212,6 +224,8 @@ class PortfolioMonitor(AbstractGatewayListener):
                 port['port_items'][contract_key] = port_item
                 port['port_items'][contract_key] = port_item
                 
                 
             self.dump_portfolio(account)    
             self.dump_portfolio(account)    
+            
+            
         
         
     def dump_portfolio(self, account):
     def dump_portfolio(self, account):
         #<account_id>: {'port_items': {<contract_key>, instrument}, 'opt_chains': {<oc_id>: option_chain}}
         #<account_id>: {'port_items': {<contract_key>, instrument}, 'opt_chains': {<oc_id>: option_chain}}
@@ -226,6 +240,65 @@ class PortfolioMonitor(AbstractGatewayListener):
         
         
          
          
     
     
+    def g_datatable_json(self, account):
+    
+        
+         
+        
+        dtj = {'cols':[], 'rows':[], 'ckey_to_row_index':{}}
+        header = [('symbol', 'Symbol', 'string'), ('right', 'Right', 'string'), ('avgcost', 'Avg Cost', 'number'), ('market_value', 'Market Value', 'number'), 
+                  ('avgpx', 'Avg Price', 'number'), ('spotpx', 'Spot Price', 'number'), ('pos', 'Quantity', 'number'), 
+                  ('delta', 'Delta', 'number'), ('theta', 'Theta', 'number'), ('gamma', 'Gamma', 'number'), 
+                  ('pos_delta', 'P. Delta', 'number'), ('pos_theta', 'P. Theta', 'number'), ('pos_gamma', 'P. Gamma', 'number'), 
+                  ('unreal_pl', 'Unreal P/L', 'number'), ('percent_gain_loss', '% gain/loss', 'number')  
+                  ]  
+        # header fields      
+        map(lambda hf: dtj['cols'].append({'id': hf[0], 'label': hf[1], 'type': hf[2]}), header)
+        
+        
+        def get_spot_px(x):
+            px = float('nan')
+            if x.get_quantity() > 0:
+                px= x.get_instrument().get_tick_value(Symbol.BID)
+            elif x.get_quantity() < 0:
+                px= x.get_instrument().get_tick_value(Symbol.ASK)
+            if px == -1:
+                return x.get_instrument().get_tick_value(Symbol.LAST)
+        
+        # table rows
+        def row_fields(x):
+            
+            rf = [{'v': '%s-%s-%s' % (x[1].get_symbol_id(), x[1].get_expiry(), x[1].get_strike())}, 
+                 {'v': x[1].get_right()},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_COST)},
+                 {'v': x[1].get_port_field(PortfolioItem.MARKET_VALUE)},
+                 {'v': x[1].get_port_field(PortfolioItem.AVERAGE_PRICE)},
+                 {'v': get_spot_px(x[1])},
+                 {'v': x[1].get_quantity()},
+                 {'v': x[1].get_instrument().get_tick_value(Option.DELTA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.THETA)},
+                 {'v': x[1].get_instrument().get_tick_value(Option.GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_DELTA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_THETA)},
+                 {'v': x[1].get_port_field(PortfolioItem.POSITION_GAMMA)},
+                 {'v': x[1].get_port_field(PortfolioItem.UNREAL_PL)},
+                 {'v': x[1].get_port_field(PortfolioItem.PERCENT_GAIN_LOSS)}]
+                 
+             
+            return rf 
+        
+        def set_contract_key_to_row_index(i):
+            dtj['ckey_to_row_index'][p2_items[i].get_instrument().get_contract_key()]['row_id'] = i
+            dtj['ckey_to_row_index'][p2_items[i].get_instrument().get_contract_key()]['dirty'] = False
+        
+        p_items = sorted([x for x in self.portfolios[account]['port_items'].iteritems()])
+        p1_items = filter(lambda x: x[1].get_symbol_id() in self.kwargs['interested_position_types']['symbol'], p_items)
+        p2_items = filter(lambda x: x[1].get_instrument_type() in self.kwargs['interested_position_types']['instrument_type'], p1_items)
+        map(lambda p: dtj['rows'].append({'c': row_fields(p)}), p2_items)
+        map(set_contract_key_to_row_index, range(len(p2_items)))
+        
+        return json.dumps(dtj) #, indent=4)            
+    
     #         EVENT_OPTION_UPDATED = 'oc_option_updated'
     #         EVENT_OPTION_UPDATED = 'oc_option_updated'
     #         EVENT_UNDERLYING_ADDED = 'oc_underlying_added
     #         EVENT_UNDERLYING_ADDED = 'oc_underlying_added
     def oc_option_updated(self, event, update_mode, name, instrument):        
     def oc_option_updated(self, event, update_mode, name, instrument):        
@@ -286,6 +359,12 @@ class PortfolioMonitor(AbstractGatewayListener):
                         if contract_key in self.portfolios[acct]['port_items']:
                         if contract_key in self.portfolios[acct]['port_items']:
                             self.portfolios[acct]['port_items'][contract_key].calculate_pl(key_greeks[0]) #, underlying_px)
                             self.portfolios[acct]['port_items'][contract_key].calculate_pl(key_greeks[0]) #, underlying_px)
                         
                         
+                            # dispatch pm_event to listeners
+                            self.mark_gtable_row_dirty(acct, contract_key, True)
+                            logging.info('PortfolioMonitor:tds_event_tick_updated...marking the affected row %d:[%s] as dirty' %
+                                                (self.portfolios[acct]['g_table']['dtj']['ckey_to_row_index']['row_id'], contract_key))
+                            
+                        
                     if results:
                     if results:
                         #logging.info('PortfolioMonitor:tds_event_tick_updated ....before map')
                         #logging.info('PortfolioMonitor:tds_event_tick_updated ....before map')
                         map(update_portfolio_fields, list(results.iteritems()))
                         map(update_portfolio_fields, list(results.iteritems()))
@@ -312,8 +391,9 @@ class PortfolioMonitor(AbstractGatewayListener):
 
 
 
 
     def tickSize(self, event, contract_key, field, size):
     def tickSize(self, event, contract_key, field, size):
-        self.tds.set_symbol_tick_size(contract_key, field, size)
+        #self.tds.set_symbol_tick_size(contract_key, field, size)
         #logging.info('MessageListener:%s. %s: %d %8.2f' % (event, contract_key, field, size))
         #logging.info('MessageListener:%s. %s: %d %8.2f' % (event, contract_key, field, size))
+        pass
  
  
     def position(self, event, account, contract_key, position, average_cost, end_batch):
     def position(self, event, account, contract_key, position, average_cost, end_batch):
         if not end_batch:
         if not end_batch:
@@ -325,9 +405,11 @@ class PortfolioMonitor(AbstractGatewayListener):
             # subscribe to automatic account updates
             # subscribe to automatic account updates
             if self.starting_engine:
             if self.starting_engine:
                 for acct in self.portfolios.keys():
                 for acct in self.portfolios.keys():
+                    self.portfolios[acct]['g_table'] = self.g_datatable_json(acct)
+                    logging.info('PortfolioMonitor:position. generate gtable for ac: [%s]' % acct)
                     self.twsc.reqAccountUpdates(True, acct)
                     self.twsc.reqAccountUpdates(True, acct)
                     logging.info('PortfolioMonitor:position. subscribing to auto updates for ac: [%s]' % acct)
                     logging.info('PortfolioMonitor:position. subscribing to auto updates for ac: [%s]' % acct)
-            self.start_engine = False
+            self.starting_engine = False
                     
                     
     '''
     '''
         the 4 account functions below are invoked by AbstractListener.update_portfolio_account.
         the 4 account functions below are invoked by AbstractListener.update_portfolio_account.
@@ -379,7 +461,8 @@ if __name__ == '__main__':
       'clear_offsets':  False,
       'clear_offsets':  False,
       'logconfig': {'level': logging.INFO, 'filemode': 'w', 'filename': '/tmp/pm.log'},
       'logconfig': {'level': logging.INFO, 'filemode': 'w', 'filename': '/tmp/pm.log'},
       'topics': ['position', 'positionEnd', 'tickPrice', 'update_portfolio_account'],
       'topics': ['position', 'positionEnd', 'tickPrice', 'update_portfolio_account'],
-      'seek_to_end': ['*']
+      'seek_to_end': ['*'],
+      'interested_position_types': {'symbol': ['HSI', 'MHI'], 'instrument_type': ['OPT', 'FUT']}
 
 
       
       
       }
       }

+ 40 - 0
src/rethink/table_model.py

@@ -0,0 +1,40 @@
+from misc2.observer import Subscriber, Publisher
+from misc2.observer import NotImplementedException
+import logging
+
+class AbstractTableModel(Publisher):
+    
+    EVENT_TM_TABLE_CELL_UPDATED = 'event_tm_table_cell_updated'
+    EVENT_TM_TABLE_ROWS_INSERTED = 'event_tm_table_rows_inserted'
+    EVENT_TM_TABLE_ROWS_UPDATED = 'event_tm_table_rows_updated'
+    TM_EVENTS = [EVENT_TM_TABLE_CELL_UPDATED, EVENT_TM_TABLE_ROWS_INSERTED, EVENT_TM_TABLE_ROWS_UPDATED]
+    
+    def register_listener(self, listener):
+        try:
+            map(lambda e: self.register(e, listener, getattr(listener, e)), AbstractTableModel.TM_EVENTS)
+        except AttributeError as e:
+            logging.error("AbstractTableModel:register_listener. Function not implemented in the listener. %s" % e)
+            raise NotImplementedException        
+        
+    def get_column_count(self):
+        raise NotImplementedException
+    
+    def get_row_count(self):
+        raise NotImplementedException
+
+    def get_column_name(self, col):
+        raise NotImplementedException
+
+    def get_column_id(self, col):
+        raise NotImplementedException
+
+    def get_value_at(self, row, col):
+        raise NotImplementedException
+    
+    def set_value_at(self, row, col, value):
+        raise NotImplementedException
+    
+    def insert_row(self, values):
+        raise NotImplementedException
+
+    

+ 145 - 218
src/rethink/temp_pm.py

@@ -1,90 +1,25 @@
+# -*- coding: utf-8 -*-
+import sys, traceback
 import logging
 import logging
 import json
 import json
 import time, datetime
 import time, datetime
 import copy
 import copy
 from optparse import OptionParser
 from optparse import OptionParser
 from time import sleep
 from time import sleep
-from misc2.observer import Subscriber
+from misc2.observer import Subscriber, Publisher
 from misc2.helpers import ContractHelper
 from misc2.helpers import ContractHelper
 from finopt.instrument import Symbol, Option
 from finopt.instrument import Symbol, Option
 from rethink.option_chain import OptionsChain
 from rethink.option_chain import OptionsChain
 from rethink.tick_datastore import TickDataStore
 from rethink.tick_datastore import TickDataStore
+from rethink.portfolio_item import PortfolioItem, PortfolioRules, Portfolio
 from comms.ibc.tws_client_lib import TWS_client_manager, AbstractGatewayListener
 from comms.ibc.tws_client_lib import TWS_client_manager, AbstractGatewayListener
-from numpy import average
 
 
 
 
 
 
-class PortfolioItem():
-    """
-        Set up some constant variables
-        
-        position
-        average cost
-    
-    """
-    POSITION = 6001
-    AVERAGE_COST = 6002
-    POSITION_DELTA = 6003
-    POSITION_THETA = 6004
-    UNREAL_PL = 6005
-    PERCENT_GAIN_LOSS = 6006
-    AVERAGE_PRICE = 6007
-    ACCOUNT_ID = 6008
-    
-        
-    def __init__(self, account, contract_key, position, average_cost):
-        
-        self.contract_key = contract_key
-        self.quantity = position
-        self.average_cost = average_cost
-        self.account_id = account
-        
-        contract = ContractHelper.makeContractfromRedisKeyEx(contract_key)
-        if contract.m_secType == 'OPT':
-            self.instrument = Option(contract)
-        else: 
-            self.instrument = Symbol(contract)
-        
-    
-    def get_instrument(self):
-        return self.instrument
-        
-    def get_instrument_type(self):
-        return self.instrument.get_contract().m_secType
-    
-    def get_account(self):
-        return self.account_id
-        
-    def calculate_pl(self):
-        pass
-    
-    def set_position_cost(self, position, average_cost):
-        self.quantity = position
-        self.average_cost = average_cost   
-
-
 class PortfolioMonitor(AbstractGatewayListener):
 class PortfolioMonitor(AbstractGatewayListener):
 
 
-  
-    '''
-        portfolios : 
-             {
-                <account_id>: {'port_items': {<contract_key>, instrument}, 'opt_chains': {<oc_id>: option_chain}}
-             }   
-                
-    '''
-    rule_map = {
-                'symbol': {'HSI' : 'FUT', 'MHI' : 'FUT', 'QQQ' : 'STK'},
-                'expiry': {'HSI' : 'same_month', 'MHI': 'same_month', 'STK': 'leave_blank'},
-                'option_structure': {
-                                        {'HSI':
-                                         {'spd_size': 200, 'multiplier': 50, 'rate': 0.0012, 'div': 0} 
-                                        },
-                                        {'MHI':
-                                         {'spd_size': 200, 'multiplier': 10, 'rate': 0.0012, 'div': 0} 
-                                        }
-                                    }
-               }    
+
+   
 
 
     def __init__(self, kwargs):
     def __init__(self, kwargs):
         self.kwargs = copy.copy(kwargs)
         self.kwargs = copy.copy(kwargs)
@@ -95,52 +30,25 @@ class PortfolioMonitor(AbstractGatewayListener):
         self.tds.register_listener(self)
         self.tds.register_listener(self)
         self.twsc.add_listener_topics(self, kwargs['topics'])
         self.twsc.add_listener_topics(self, kwargs['topics'])
         
         
+        '''
+            portfolios: {<account>: <portfolio>}
+        '''
         self.portfolios = {}
         self.portfolios = {}
-        self.option_chains = {}
-        
-    
-    def test_oc(self, oc2):
-        expiry = '20170427'
-        contractTuple = ('HSI', 'FUT', 'HKFE', 'HKD', expiry, 0, '')
-        contract = ContractHelper.makeContract(contractTuple)  
         
         
-        oc2.set_option_structure(contract, 200, 50, 0.0012, 0.0328, expiry)        
-        
-        oc2.build_chain(24172, 0.04, 0.22)
-        
-#         expiry='20170324'
-#         contractTuple = ('QQQ', 'STK', 'SMART', 'USD', '', 0, '')
-#         contract = ContractHelper.makeContract(contractTuple)  
-# 
-#         oc2.set_option_structure(contract, 0.5, 100, 0.0012, 0.0328, expiry)        
-#     
-#         oc2.build_chain(132.11, 0.02, 0.22)
-        
-        
-        oc2.pretty_print()        
-
-        for o in oc2.get_option_chain():
-            self.tds.add_symbol(o)
-        self.tds.add_symbol(oc2.get_underlying())
         
         
     
     
-        
-    
-    
     def start_engine(self):
     def start_engine(self):
         self.twsc.start_manager()
         self.twsc.start_manager()
-        oc2 = OptionsChain('oc2')
-        oc2.register_listener(self)
-        self.test_oc(oc2)
-        self.option_chains[oc2.name] = oc2
-        
+        self.twsc.reqPositions()
+        self.starting_engine = True
         try:
         try:
             logging.info('PortfolioMonitor:main_loop ***** accepting console input...')
             logging.info('PortfolioMonitor:main_loop ***** accepting console input...')
             menu = {}
             menu = {}
-            menu['1']="Display option chain oc2" 
-            menu['2']="Display tick data store "
-            menu['3']="Display option chain oc3"
-            menu['4']="Generate oc3 gtable json"
+            menu['1']="Request position" 
+            menu['2']="Portfolio dump dtj"
+            menu['3']="TDS dump"
+            menu['4']="Request account updates"
+            menu['5']="Table chart JSON"
             menu['9']="Exit"
             menu['9']="Exit"
             while True: 
             while True: 
                 choices=menu.keys()
                 choices=menu.keys()
@@ -150,9 +58,19 @@ class PortfolioMonitor(AbstractGatewayListener):
 
 
                 selection = raw_input("Enter command:")
                 selection = raw_input("Enter command:")
                 if selection =='1':
                 if selection =='1':
-                    oc2.pretty_print()
+                    self.twsc.reqPositions()
                 elif selection == '2': 
                 elif selection == '2': 
-                    self.tds.dump()
+                    for acct in self.portfolios.keys():
+                        #print self.dump_portfolio(acct)
+                        print self.portfolios[acct]['g_table']
+                elif selection == '3': 
+                    print self.tds.dump()
+                elif selection == '4': 
+                    for acct in self.portfolios.keys():
+                        self.twsc.reqAccountUpdates(True, acct)
+                elif selection == '5':
+                    for acct in self.portfolios.keys():
+                        print self.g_datatable_json(acct)
                 elif selection == '9': 
                 elif selection == '9': 
                     self.twsc.gw_message_handler.set_stop()
                     self.twsc.gw_message_handler.set_stop()
                     break
                     break
@@ -166,57 +84,32 @@ class PortfolioMonitor(AbstractGatewayListener):
             self.twsc.gw_message_handler.set_stop() 
             self.twsc.gw_message_handler.set_stop() 
             logging.info('PortfolioMonitor: Service shut down complete...')               
             logging.info('PortfolioMonitor: Service shut down complete...')               
     
     
-    def is_contract_in_portfolio(self, account, contract_key):
-        return self.get_portfolio_port_items(account, contract_key)
-            
-    def get_portfolio_port_items(self, account, contract_key):
-        try:
-            return self.portfolios[account]['port_items'][contract_key]
-        except KeyError:
-            return None
-    
-    def set_portfolio_port_items(self, account, contract_key, port_item):
-        self.portfolios[account]['port_items'][contract_key] = port_item
         
         
-        
-    def create_empty_portfolio(self, account):
-        port = self.portfolios[account] = {}
-        self.portfolios[account]['port_items']=  {}
-        self.portfolios[account]['opt_chains']=  {}
-        return port
                 
                 
     def get_portfolio(self, account):
     def get_portfolio(self, account):
         try:
         try:
             return self.portfolios[account]
             return self.portfolios[account]
         except KeyError:
         except KeyError:
-            self.portfolios[account] = self.create_empty_portfolio(account)
+            self.portfolios[account] = Portfolio(account)
         return self.portfolios[account]
         return self.portfolios[account]
     
     
     def deduce_option_underlying(self, option):
     def deduce_option_underlying(self, option):
         '''
         '''
             given an Option object, return the underlying Symbol object
             given an Option object, return the underlying Symbol object
         '''
         '''
-        
-
         try:
         try:
             symbol_id = option.get_contract().m_symbol
             symbol_id = option.get_contract().m_symbol
-            underlying_sectype = self.rule_map['symbol'][symbol_id]
+            underlying_sectype = PortfolioRules.rule_map['symbol'][symbol_id]
             exchange = option.get_contract().m_exchange
             exchange = option.get_contract().m_exchange
             currency = option.get_contract().m_currency
             currency = option.get_contract().m_currency
-            expiry = option.get_contract().m_expiry if self.rule_map['expiry'][symbol_id] ==  'same_month' else ''
+            expiry = option.get_contract().m_expiry if PortfolioRules.rule_map['expiry'][symbol_id] ==  'same_month' else ''
             contractTuple = (symbol_id, underlying_sectype, exchange, currency, expiry, 0, '')
             contractTuple = (symbol_id, underlying_sectype, exchange, currency, expiry, 0, '')
             logging.info('PortfolioMonitor:deduce_option_underlying. Deduced underlying==> %s' %
             logging.info('PortfolioMonitor:deduce_option_underlying. Deduced underlying==> %s' %
-                          ContractHelper.printContract(contractTuple))
+                          str(contractTuple))
             return Symbol(ContractHelper.makeContract(contractTuple))
             return Symbol(ContractHelper.makeContract(contractTuple))
         except KeyError:
         except KeyError:
             logging.error('PortfolioMonitor:deduce_option_underlying. Unable to deduce the underlying for the given option %s' %
             logging.error('PortfolioMonitor:deduce_option_underlying. Unable to deduce the underlying for the given option %s' %
                           ContractHelper.printContract(option.get_contract))
                           ContractHelper.printContract(option.get_contract))
-        
-        
-    def is_oc_in_portfolio(self, account, oc_id):
-        try:
-            return self.portfolios[account]['opt_chains'][oc_id]
-        except KeyError:
             return None
             return None
         
         
         
         
@@ -229,45 +122,40 @@ class PortfolioMonitor(AbstractGatewayListener):
         underlying_id = underlying.get_contract().m_symbol
         underlying_id = underlying.get_contract().m_symbol
         month = underlying.get_contract().m_expiry
         month = underlying.get_contract().m_expiry
         oc_id = create_oc_id(account, underlying_id, month)
         oc_id = create_oc_id(account, underlying_id, month)
-        oc = self.is_oc_in_portfolio(account, oc_id)
+        oc = self.portfolios[account].is_oc_in_portfolio(oc_id)
         if oc == None:
         if oc == None:
             oc = OptionsChain(oc_id)
             oc = OptionsChain(oc_id)
-            oc.set_option_structure(underlying,
-                                    self.rule_map['option_structure'][underlying_id]['spd_size'],
-                                    self.rule_map['option_structure'][underlying_id]['multiplier'],
-                                    self.rule_map['option_structure'][underlying_id]['rate'],
-                                    self.rule_map['option_structure'][underlying_id]['div'],
-                                    month)
+            oc.register_listener(self)
+            oc.set_option_structure(underlying.get_contract(),
+                                    PortfolioRules.rule_map['option_structure'][underlying_id]['spd_size'],
+                                    PortfolioRules.rule_map['option_structure'][underlying_id]['multiplier'],
+                                    PortfolioRules.rule_map['option_structure'][underlying_id]['rate'],
+                                    PortfolioRules.rule_map['option_structure'][underlying_id]['div'],
+                                    month,
+                                    PortfolioRules.rule_map['option_structure'][underlying_id]['trade_vol'])
             
             
-            self.portfolios[account]['opt_chains'][oc_id] = oc 
+            self.portfolios[account].set_option_chain(oc_id, oc) 
             
             
             
             
         return oc
         return oc
     
     
+
     
     
-    
-    def process_position(self, account, contract_key, position, average_cost):
+    def process_position(self, account, contract_key, position, average_cost, extra_info=None):
         
         
-        # look up the portfolio from the account code 
+        # obtain a reference to the portfolio, if not exist create a new one 
         port = self.get_portfolio(account)
         port = self.get_portfolio(account)
-        port_item = None
-        if port:
-            # look up the position in the portfolio
-            port_item =  self.is_contract_in_portfolio(account, contract_key)
-        else:
-            # create a new portfolio
-            port = self.create_empty_portfolio(account)
-
-            
-            
+        port_item =  port.is_contract_in_portfolio(contract_key)
         if port_item:
         if port_item:
             # update the values and recalculate p/l
             # update the values and recalculate p/l
-            port_item.set_position(position, average_cost)
-            port_item.calculate_pl()
+            port_item.update_position(position, average_cost, extra_info)
+            port_item.calculate_pl(contract_key)
+            
         # new position 
         # new position 
         else:
         else:
             port_item = PortfolioItem(account, contract_key, position, average_cost)
             port_item = PortfolioItem(account, contract_key, position, average_cost)
-            port['port_items'][contract_key] = port_item
+            #port['port_items'][contract_key] = port_item
+            port.set_portfolio_port_item(contract_key, port_item)
             instrument = port_item.get_instrument()
             instrument = port_item.get_instrument()
             self.tds.add_symbol(instrument)
             self.tds.add_symbol(instrument)
             self.twsc.reqMktData(instrument.get_contract(), True)
             self.twsc.reqMktData(instrument.get_contract(), True)
@@ -281,17 +169,19 @@ class PortfolioMonitor(AbstractGatewayListener):
                 underlying = self.deduce_option_underlying(instrument)
                 underlying = self.deduce_option_underlying(instrument)
                 if underlying:
                 if underlying:
                     oc = self.get_portfolio_option_chain(account, underlying)
                     oc = self.get_portfolio_option_chain(account, underlying)
-                    oc.add_option(instrument)
                     instrument.set_extra_attributes(OptionsChain.CHAIN_IDENTIFIER, oc.get_name())
                     instrument.set_extra_attributes(OptionsChain.CHAIN_IDENTIFIER, oc.get_name())
+                    oc.add_option(instrument)
                 else:
                 else:
-                    logging.error('PortfolioMonitor:process_position. Error in adding the new position %s' % contract_key)
+                    logging.error('PortfolioMonitor:process_position. **** Error in adding the new position %s' % contract_key)
             # non options. stocks, futures that is...    
             # non options. stocks, futures that is...    
             else:
             else:
-                port['port_items'][contract_key] = port_item
-                
+                logging.info('PortfolioMonitor:process_position. Adding a new non-option position into the portfolio [%s]' % port_item.dump())
+                #port['port_items'][contract_key] = port_item
+                port.set_portfolio_port_item(contract_key, port_item)
                 
                 
-        
-        
+            #self.dump_portfolio(account)    
+            port.dump_portfolio()
+            
     
     
     #         EVENT_OPTION_UPDATED = 'oc_option_updated'
     #         EVENT_OPTION_UPDATED = 'oc_option_updated'
     #         EVENT_UNDERLYING_ADDED = 'oc_underlying_added
     #         EVENT_UNDERLYING_ADDED = 'oc_underlying_added
@@ -307,16 +197,9 @@ class PortfolioMonitor(AbstractGatewayListener):
         self.tds.add_symbol(instrument)
         self.tds.add_symbol(instrument)
         self.twsc.reqMktData(instrument.get_contract(), True)
         self.twsc.reqMktData(instrument.get_contract(), True)
 
 
-    #
-    # tds call backs
-    #
-    #     
-    #         EVENT_TICK_UPDATED = 'tds_event_tick_updated'
-    #         EVENT_SYMBOL_ADDED = 'tds_event_symbol_added'
-    #         EVENT_SYMBOL_DELETED = 'tds_event_symbol_deleted'    
     
     
     def tds_event_symbol_added(self, event, update_mode, name, instrument):
     def tds_event_symbol_added(self, event, update_mode, name, instrument):
-       pass
+        pass
         #logging.info('tds_event_new_symbol_added. %s' % ContractHelper.object2kvstring(symbol.get_contract()))
         #logging.info('tds_event_new_symbol_added. %s' % ContractHelper.object2kvstring(symbol.get_contract()))
         
         
     
     
@@ -327,27 +210,48 @@ class PortfolioMonitor(AbstractGatewayListener):
             if OptionsChain.CHAIN_IDENTIFIER in s.get_extra_attributes():
             if OptionsChain.CHAIN_IDENTIFIER in s.get_extra_attributes():
                 results = {}
                 results = {}
                 chain_id = s.get_extra_attributes()[OptionsChain.CHAIN_IDENTIFIER]
                 chain_id = s.get_extra_attributes()[OptionsChain.CHAIN_IDENTIFIER]
-                logging.info('PortfolioMonitor:tds_event_tick_updated chain_id %s' % chain_id)
-                if chain_id  in self.option_chains.keys():
-                    if 'FUT' in contract_key or 'STK' in contract_key:
-                        results = self.option_chains[chain_id].cal_greeks_in_chain(self.kwargs['evaluation_date'])
-                    else:
-                        results[ContractHelper.makeRedisKeyEx(s.get_contract())] = self.option_chains[chain_id].cal_option_greeks(s, self.kwargs['evaluation_date'])
-                logging.info('AnalysticsEngine:tds_event_tick_updated. compute greek results %s' % results)    
-                # set_analytics(self, imvol=None, delta=None, gamma=None, theta=None, vega=None, npv=None):
-                # 
-                def update_tds_analytics(key_greeks):
+                #logging.info('PortfolioMonitor:tds_event_tick_updated chain_id %s' % chain_id)
+                
+                for acct in self.portfolios:
                     
                     
-                    self.tds.set_symbol_analytics(key_greeks[0], Option.IMPL_VOL, key_greeks[1][Option.IMPL_VOL])
-                    self.tds.set_symbol_analytics(key_greeks[0], Option.DELTA, key_greeks[1][Option.DELTA])
-                    self.tds.set_symbol_analytics(key_greeks[0], Option.GAMMA, key_greeks[1][Option.GAMMA])
-                    self.tds.set_symbol_analytics(key_greeks[0], Option.THETA, key_greeks[1][Option.THETA])
-                    self.tds.set_symbol_analytics(key_greeks[0], Option.VEGA, key_greeks[1][Option.VEGA])
+                    #if chain_id  in self.portfolios[acct]['opt_chains'].keys():
+                    if chain_id in self.portfolios[acct].get_option_chains():
+                        #logging.info('PortfolioMonitor:tds_event_tick_updated --> portfolio opt_chains: [  %s  ] ' % 
+                        #             str(self.portfolios[acct]['opt_chains'].keys()))
+                        if 'FUT' in contract_key or 'STK' in contract_key:
+                            #results = self.portfolios[acct]['opt_chains'][chain_id].cal_greeks_in_chain(self.kwargs['evaluation_date'])
+                            results = self.portfolios[acct].get_option_chain(chain_id).cal_greeks_in_chain(self.kwargs['evaluation_date'])
+                        else:
+                            #results[ContractHelper.makeRedisKeyEx(s.get_contract())] =  self.portfolios[acct]['opt_chains'][chain_id].cal_option_greeks(s, self.kwargs['evaluation_date'])
+                            results[ContractHelper.makeRedisKeyEx(s.get_contract())] =  self.portfolios[acct].get_option_chain(chain_id).cal_option_greeks(s, self.kwargs['evaluation_date'])
+                    #logging.info('PortfolioMonitor:tds_event_tick_updated. compute greek results %s' % results)
+                        
+                        #underlying_px = self.portfolios[acct]['opt_chains'][chain_id].get_underlying().get_tick_value(4)
+                        
+                    def update_portfolio_fields(key_greeks):
+                        
+                        self.tds.set_symbol_analytics(key_greeks[0], Option.IMPL_VOL, key_greeks[1][Option.IMPL_VOL])
+                        self.tds.set_symbol_analytics(key_greeks[0], Option.DELTA, key_greeks[1][Option.DELTA])
+                        self.tds.set_symbol_analytics(key_greeks[0], Option.GAMMA, key_greeks[1][Option.GAMMA])
+                        self.tds.set_symbol_analytics(key_greeks[0], Option.THETA, key_greeks[1][Option.THETA])
+                        self.tds.set_symbol_analytics(key_greeks[0], Option.VEGA, key_greeks[1][Option.VEGA])
+                        
+                        #if contract_key in self.portfolios[acct]['port_items']:
+                        if self.portfolios[acct].is_contract_in_portfolio(contract_key):
+                            #self.portfolios[acct]['port_items'][contract_key].calculate_pl(key_greeks[0]) #, underlying_px)
+                            self.portfolios[acct].calculate_item_pl(contract_key)
+                        
+                            
+                        
+                    if results:
+                        #logging.info('PortfolioMonitor:tds_event_tick_updated ....before map')
+                        map(update_portfolio_fields, list(results.iteritems()))
+                        #logging.info('PortfolioMonitor:tds_event_tick_updated ....after map')
+                           
+                               
                     
                     
-                map(update_tds_analytics, list(results.iteritems()))                
-
             else:
             else:
-                
+                logging.info('PortfolioMonitor:tds_event_tick_updated ignoring uninterested ticks %s' % contract_key)
                 continue
                 continue
              
              
         
         
@@ -355,42 +259,64 @@ class PortfolioMonitor(AbstractGatewayListener):
 
 
     def tds_event_symbol_deleted(self, event, update_mode, name, instrument):
     def tds_event_symbol_deleted(self, event, update_mode, name, instrument):
         pass
         pass
-    #
-    # external ae requests
-    #
-    def ae_req_greeks(self, event, message_value):
-        #(int tickerId, int field, double impliedVol, double delta, double optPrice, double pvDividend, double gamma, double vega, double theta, double undPrice) 
-        pass
-    
-    def ae_req_tds_internal(self, event, message_value):
-        logging.info('received ae_req_tds_internal')
-        self.tds.dump()
-    
-    #
-    # gateway events
-    #
+
 
 
     def tickPrice(self, event, contract_key, field, price, canAutoExecute):
     def tickPrice(self, event, contract_key, field, price, canAutoExecute):
-        logging.debug('MessageListener:%s. %s %d %8.2f' % (event, contract_key, field, price))
         self.tds.set_symbol_tick_price(contract_key, field, price, canAutoExecute)
         self.tds.set_symbol_tick_price(contract_key, field, price, canAutoExecute)
 
 
 
 
     def tickSize(self, event, contract_key, field, size):
     def tickSize(self, event, contract_key, field, size):
-        self.tds.set_symbol_tick_size(contract_key, field, size)
+        #self.tds.set_symbol_tick_size(contract_key, field, size)
         #logging.info('MessageListener:%s. %s: %d %8.2f' % (event, contract_key, field, size))
         #logging.info('MessageListener:%s. %s: %d %8.2f' % (event, contract_key, field, size))
+        pass
  
  
     def position(self, event, account, contract_key, position, average_cost, end_batch):
     def position(self, event, account, contract_key, position, average_cost, end_batch):
-        self.process_position(account, contract_key, position, average_cost)
+        if not end_batch:
+            #logging.info('PortfolioMonitor:position. received position message contract=%s' % contract_key)
+            self.process_position(account, contract_key, position, average_cost)
    
    
-    def positionEnd(self, event): #, message_value):
-        """ generated source for method positionEnd """
-        logging.info('%s [[ %s ]]' % (event, vars()))
+        else:
+            # to be run once during start up
+            # subscribe to automatic account updates
+            if self.starting_engine:
+                for acct in self.portfolios.keys():
+                    self.portfolios[acct].g_datatable_json()
+                    logging.info('PortfolioMonitor:position. generate gtable for ac: [%s]' % acct)
+                    self.twsc.reqAccountUpdates(True, acct)
+                    logging.info('PortfolioMonitor:position. subscribing to auto updates for ac: [%s]' % acct)
+            self.starting_engine = False
+                    
+    '''
+        the 4 account functions below are invoked by AbstractListener.update_portfolio_account.
+        the original message from TWS is first wrapped into update_portfolio_account event in 
+        class TWS_event_handler and then expanded by AbstractListener.update_portfolio_account
+        (check tws_event_hander)
+    '''
 
 
+                
+    def updateAccountValue(self, event, key, value, currency, account):  # key, value, currency, accountName):
+        self.raw_dump(event, vars())
  
  
+    def updatePortfolio(self, event, contract_key, position, market_price, market_value, average_cost, unrealized_PNL, realized_PNL, account):
+        self.raw_dump(event, vars())
+        self.process_position(account, contract_key, position, average_cost, 
+                              {'market_price':market_price, 'market_value':market_value, 'unrealized_PNL': unrealized_PNL, 'realized_PNL': realized_PNL})
+        
+            
+    def updateAccountTime(self, event, timestamp):
+        self.raw_dump(event, vars())
+        
+    def accountDownloadEnd(self, event, account):  # accountName):
+        self.raw_dump(event, vars())
+
  
  
     def error(self, event, message_value):
     def error(self, event, message_value):
         logging.info('PortfolioMonitor:%s. val->[%s]' % (event, message_value))         
         logging.info('PortfolioMonitor:%s. val->[%s]' % (event, message_value))         
         
         
+    def raw_dump(self, event, items):
+        del(items['self'])
+        logging.info('%s [[ %s ]]' % (event, items))      
+        
         
         
 if __name__ == '__main__':
 if __name__ == '__main__':
     
     
@@ -406,12 +332,13 @@ if __name__ == '__main__':
       'tws_host': 'localhost',
       'tws_host': 'localhost',
       'tws_api_port': 8496,
       'tws_api_port': 8496,
       'tws_app_id': 38868,
       'tws_app_id': 38868,
-      'group_id': 'AE',
+      'group_id': 'PM',
       'session_timeout_ms': 10000,
       'session_timeout_ms': 10000,
       'clear_offsets':  False,
       'clear_offsets':  False,
       'logconfig': {'level': logging.INFO, 'filemode': 'w', 'filename': '/tmp/pm.log'},
       'logconfig': {'level': logging.INFO, 'filemode': 'w', 'filename': '/tmp/pm.log'},
-      'topics': ['tickPrice', 'tickSize'],
-      'seek_to_end': ['*']
+      'topics': ['position', 'positionEnd', 'tickPrice', 'update_portfolio_account'],
+      'seek_to_end': ['*'],
+      'interested_position_types': {'symbol': ['HSI', 'MHI'], 'instrument_type': ['OPT', 'FUT']}
 
 
       
       
       }
       }

+ 15 - 0
src/sh/pm.sh

@@ -0,0 +1,15 @@
+#!/bin/bash
+
+
+HOST=$(hostname)
+echo $HOST
+if [ $HOST == 'hkc-larryc-vm1' ]; then
+	FINOPT_HOME=~/ironfly-workspace/finopt/src
+elif [ $HOST == 'vorsprung' ]; then
+	FINOPT_HOME=~/workspace/finopt/src	
+else
+	FINOPT_HOME=~/l1304/workspace/finopt-ironfly/finopt/src
+fi
+export PYTHONPATH=$FINOPT_HOME:$PYTHONPATH
+#python $FINOPT_HOME/rethink/portfolio_monitor.py  -c -g PM1  
+python $FINOPT_HOME/rethink/portfolio_monitor.py  -g PM1  

+ 2 - 2
src/sh/start_twsgw.sh

@@ -13,7 +13,7 @@ fi
 export PYTHONPATH=$FINOPT_HOME:$PYTHONPATH
 export PYTHONPATH=$FINOPT_HOME:$PYTHONPATH
 #  
 #  
 # clear all topic offsets and erased saved subscriptions
 # clear all topic offsets and erased saved subscriptions
-#python $FINOPT_HOME/comms/ibgw/tws_gateway.py -r -c -f $FINOPT_HOME/config/tws_gateway.cfg 
+python $FINOPT_HOME/comms/ibgw/tws_gateway.py -r -c -f $FINOPT_HOME/config/tws_gateway.cfg 
 
 
 
 
 #
 #
@@ -25,4 +25,4 @@ export PYTHONPATH=$FINOPT_HOME:$PYTHONPATH
 #python $FINOPT_HOME/comms/ibgw/tws_gateway.py  -r -f $FINOPT_HOME/config/tws_gateway.cfg 
 #python $FINOPT_HOME/comms/ibgw/tws_gateway.py  -r -f $FINOPT_HOME/config/tws_gateway.cfg 
 
 
 # normal restart - keep the offsets and reload from saved subscription entries
 # normal restart - keep the offsets and reload from saved subscription entries
-python $FINOPT_HOME/comms/ibgw/tws_gateway.py   -f $FINOPT_HOME/config/tws_gateway.cfg 
+#python $FINOPT_HOME/comms/ibgw/tws_gateway.py   -f $FINOPT_HOME/config/tws_gateway.cfg 

+ 0 - 0
src/ws/__init__.py


+ 65 - 0
src/ws/client.html

@@ -0,0 +1,65 @@
+<html>
+<head>
+  <title>Simple client</title>
+
+  <script type="text/javascript">
+
+    var ws;
+    
+    function init() {
+
+      // Connect to Web Socket
+      ws = new WebSocket("ws://localhost:9001/");
+
+      // Set event handlers.
+      ws.onopen = function() {
+        output("onopen");
+      };
+      
+      ws.onmessage = function(e) {
+        // e.data contains received string.
+        output("onmessage: " + e.data);
+      };
+      
+      ws.onclose = function() {
+        output("onclose");
+      };
+
+      ws.onerror = function(e) {
+        output("onerror");
+        console.log(e)
+      };
+
+    }
+    
+    function onSubmit() {
+      var input = document.getElementById("input");
+      // You can send message to the Web Socket using ws.send.
+      ws.send(input.value);
+      output("send: " + input.value);
+      input.value = "";
+      input.focus();
+    }
+    
+    function onCloseClick() {
+      ws.close();
+    }
+    
+    function output(str) {
+      var log = document.getElementById("log");
+      var escaped = str.replace(/&/, "&amp;").replace(/</, "&lt;").
+        replace(/>/, "&gt;").replace(/"/, "&quot;"); // "
+      log.innerHTML = escaped + "<br>" + log.innerHTML;
+    }
+
+  </script>
+</head>
+<body onload="init();">
+  <form onsubmit="onSubmit(); return false;">
+    <input type="text" id="input">
+    <input type="submit" value="Send">
+    <button onclick="onCloseClick(); return false;">close</button>
+  </form>
+  <div id="log"></div>
+</body>
+</html>

文件差異過大導致無法顯示
+ 66 - 0
src/ws/client_g.html


+ 26 - 0
src/ws/server.py

@@ -0,0 +1,26 @@
+from websocket_server import WebsocketServer
+
+# Called for every client connecting (after handshake)
+def new_client(client, server):
+	print("New client connected and was given id %d" % client['id'])
+	server.send_message_to_all("Hey all, a new client has joined us")
+
+
+# Called for every client disconnecting
+def client_left(client, server):
+	print("Client(%d) disconnected" % client['id'])
+
+
+# Called when a client sends a message
+def message_received(client, server, message):
+	if len(message) > 200:
+		message = message[:200]+'..'
+	print("Client(%d) said: %s" % (client['id'], message))
+
+
+PORT=9001
+server = WebsocketServer(PORT)
+server.set_fn_new_client(new_client)
+server.set_fn_client_left(client_left)
+server.set_fn_message_received(message_received)
+server.run_forever()

+ 62 - 0
src/ws/server2.py

@@ -0,0 +1,62 @@
+from websocket_server import WebsocketServer
+import threading, logging, time, traceback
+import json
+# https://github.com/Pithikos/python-websocket-server
+
+
+class WebSocketServerWrapper(threading.Thread):
+    def __init__(self, name):
+        threading.Thread.__init__(self, name=name)
+        self.clients = {}
+        
+    def set_server(self, server):
+        self.server = server
+        
+    def run(self):   
+        print 'started...'
+        while 1:
+            time.sleep(1.5)
+            #print 'sending stuff.. %s' % str(list(self.clients.iteritems()))
+            map(lambda x: self.server.send_message(x[1], 'msg to %d: %s' % (x[0], time.ctime())), list(self.clients.iteritems()))
+            
+            
+
+            
+    def new_client(self, client, server):
+        print("New client connected and was given id %d" % client['id'])
+        self.clients[client['id']] = client
+        server.send_message_to_all("Hey all, a new client has joined us")
+    
+    
+    # Called for every client disconnecting
+    def client_left(self, client, server):
+        del self.clients[client['id']]
+        print("Client(%d) disconnected" % client['id'])
+    
+    
+    # Called when a client sends a message
+    def message_received(self, client, server, message):
+        if len(message) > 200:
+            message = message[:200]+'..'
+        print("Client(%d) said: %s" % (client['id'], message))
+    
+
+def main():
+    wsw = WebSocketServerWrapper('hello')    
+    wsw.start()
+    PORT=9001
+    server = WebsocketServer(PORT)
+    wsw.set_server(server)
+    server.set_fn_new_client(wsw.new_client)
+    server.set_fn_client_left(wsw.client_left)
+    server.set_fn_message_received(wsw.message_received)
+    server.run_forever()
+    
+    
+if __name__ == "__main__":
+    logging.basicConfig(
+        format='%(asctime)s.%(msecs)s:%(name)s:%(thread)d:%(levelname)s:%(process)d:%(message)s',
+        level=logging.INFO
+        )
+    main()
+    

文件差異過大導致無法顯示
+ 19 - 0
src/ws/server_g1.py


+ 15 - 0
src/ws/setup.py

@@ -0,0 +1,15 @@
+from setuptools import setup, find_packages
+
+setup(
+    name='websocket_server',
+    version='0.4',
+    packages=find_packages("."),
+    url='https://github.com/Pithikos/python-websocket-server',
+    license='MIT',
+    author='Johan Hanssen Seferidis',
+    author_email='manossef@gmail.com',
+    install_requires=[
+    ],
+    description='A simple fully working websocket-server in Python with no external dependencies',
+    platforms='any',
+)

+ 13 - 0
src/ws/tests/README.md

@@ -0,0 +1,13 @@
+Testing
+--------
+
+Run unit tests
+
+    pytest
+
+
+Run functional tests
+
+     python message_lengths.py
+
+Open client.html in the browser and refresh consequently until all test cases pass.

+ 6 - 0
src/ws/tests/_bootstrap_.py

@@ -0,0 +1,6 @@
+#Bootstrap
+import sys, os
+if os.getcwd().endswith('tests'):
+	sys.path.insert(0, '..')
+elif os.getcwd().endswith('websocket-server'):
+	sys.path.insert(0, '.')

+ 57 - 0
src/ws/tests/message_lengths.py

@@ -0,0 +1,57 @@
+import _bootstrap_
+from websocket_server import WebsocketServer
+from time import sleep
+from testsuite.messages import *
+
+'''
+This creates just a server that will send a different message to every new connection:
+
+    1. A message of length less than 126
+    2. A message of length 126
+    3. A message of length 127
+    4. A message of length bigger than 127
+    5. A message above 1024
+    6. A message above 65K
+    7. An enormous message (well beyond 65K)
+
+
+Reconnect to get the next message
+'''
+
+
+counter = 0
+
+# Called for every client connecting (after handshake)
+def new_client(client, server):
+	print("New client connected and was given id %d" % client['id'])
+	global counter
+	if counter == 0:
+		print("Sending message 1 of length %d" % len(msg_125B))
+		server.send_message(client, msg_125B)
+	elif counter == 1:
+		print("Sending message 2 of length %d" % len(msg_126B))
+		server.send_message(client, msg_126B)
+	elif counter == 2:
+		print("Sending message 3 of length %d" % len(msg_127B))
+		server.send_message(client, msg_127B)
+	elif counter == 3:
+		print("Sending message 4 of length %d" % len(msg_208B))
+		server.send_message(client, msg_208B)
+	elif counter == 4:
+		print("Sending message 5 of length %d" % len(msg_1251B))
+		server.send_message(client, msg_1251B)
+	elif counter == 5:
+		print("Sending message 6 of length %d" % len(msg_68KB))
+		server.send_message(client, msg_68KB)
+	elif counter == 6:
+		print("Sending message 7 of length %d" % len(msg_1500KB))
+		server.send_message(client, msg_1500KB)
+	else:
+		print("No errors")
+	counter += 1
+
+
+PORT=9001
+server = WebsocketServer(PORT)
+server.set_fn_new_client(new_client)
+server.run_forever()

+ 29 - 0
src/ws/tests/test_handshake.py

@@ -0,0 +1,29 @@
+import _bootstrap_
+from websocket_server import *
+import pytest
+
+
+class DummyWebsocketHandler(WebSocketHandler):
+    def __init__(self, *_):
+        pass
+
+@pytest.fixture
+def websocket_handler():
+	return DummyWebsocketHandler()
+
+def test_hash_calculations_for_response(websocket_handler):
+	key = 'zyjFH2rQwrTtNFk5lwEMQg=='
+	expected_key = '2hnZADGmT/V1/w1GJYBtttUKASY='
+	assert websocket_handler.calculate_response_key(key) == expected_key
+
+
+def test_response_messages(websocket_handler):
+	key = 'zyjFH2rQwrTtNFk5lwEMQg=='
+	expected = \
+		'HTTP/1.1 101 Switching Protocols\r\n'\
+		'Upgrade: websocket\r\n'              \
+		'Connection: Upgrade\r\n'             \
+		'Sec-WebSocket-Accept: 2hnZADGmT/V1/w1GJYBtttUKASY=\r\n'\
+		'\r\n'
+	handshake_content = websocket_handler.make_handshake_response(key)
+	assert handshake_content == expected

+ 0 - 0
src/ws/tests/testsuite/__init__.py


+ 21 - 0
src/ws/tests/testsuite/messages.py

@@ -0,0 +1,21 @@
+#
+# Fixed messages by length
+# Every message ends with its length..
+#
+
+msg_125B   = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqr125'
+msg_126B   = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrs126'
+msg_127B   = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrst127'
+msg_208B   = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\
+             'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208'
+msg_1251B  = (msg_125B*10)+'1'       # 1251
+msg_68KB   = ('a'*67995)+'68000'     # 68000
+msg_1500KB = ('a'*1500000)+'1500000' # 1.5Mb

+ 1 - 0
src/ws/websocket_server/__init__.py

@@ -0,0 +1 @@
+from .websocket_server import *

+ 346 - 0
src/ws/websocket_server/websocket_server.py

@@ -0,0 +1,346 @@
+# Author: Johan Hanssen Seferidis
+# License: MIT
+
+import re
+import sys
+import struct
+from base64 import b64encode
+from hashlib import sha1
+import logging
+
+if sys.version_info[0] < 3:
+    from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
+else:
+    from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
+
+logger = logging.getLogger(__name__)
+logging.basicConfig()
+
+'''
++-+-+-+-+-------+-+-------------+-------------------------------+
+ 0                   1                   2                   3
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
++-+-+-+-+-------+-+-------------+-------------------------------+
+|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
+|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
+|N|V|V|V|       |S|             |   (if payload len==126/127)   |
+| |1|2|3|       |K|             |                               |
++-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+|     Extended payload length continued, if payload len == 127  |
++ - - - - - - - - - - - - - - - +-------------------------------+
+|                     Payload Data continued ...                |
++---------------------------------------------------------------+
+'''
+
+FIN    = 0x80
+OPCODE = 0x0f
+MASKED = 0x80
+PAYLOAD_LEN = 0x7f
+PAYLOAD_LEN_EXT16 = 0x7e
+PAYLOAD_LEN_EXT64 = 0x7f
+
+OPCODE_CONTINUATION = 0x0
+OPCODE_TEXT         = 0x1
+OPCODE_BINARY       = 0x2
+OPCODE_CLOSE_CONN   = 0x8
+OPCODE_PING         = 0x9
+OPCODE_PONG         = 0xA
+
+
+# -------------------------------- API ---------------------------------
+
+class API():
+
+    def run_forever(self):
+        try:
+            logger.info("Listening on port %d for clients.." % self.port)
+            self.serve_forever()
+        except KeyboardInterrupt:
+            self.server_close()
+            logger.info("Server terminated.")
+        except Exception as e:
+            logger.error(str(e), exc_info=True)
+            exit(1)
+
+    def new_client(self, client, server):
+        pass
+
+    def client_left(self, client, server):
+        pass
+
+    def message_received(self, client, server, message):
+        pass
+
+    def set_fn_new_client(self, fn):
+        self.new_client = fn
+
+    def set_fn_client_left(self, fn):
+        self.client_left = fn
+
+    def set_fn_message_received(self, fn):
+        self.message_received = fn
+
+    def send_message(self, client, msg):
+        self._unicast_(client, msg)
+
+    def send_message_to_all(self, msg):
+        self._multicast_(msg)
+
+
+# ------------------------- Implementation -----------------------------
+
+class WebsocketServer(ThreadingMixIn, TCPServer, API):
+    """
+	A websocket server waiting for clients to connect.
+
+    Args:
+        port(int): Port to bind to
+        host(str): Hostname or IP to listen for connections. By default 127.0.0.1
+            is being used. To accept connections from any client, you should use
+            0.0.0.0.
+        loglevel: Logging level from logging module to use for logging. By default
+            warnings and errors are being logged.
+
+    Properties:
+        clients(list): A list of connected clients. A client is a dictionary
+            like below.
+                {
+                 'id'      : id,
+                 'handler' : handler,
+                 'address' : (addr, port)
+                }
+    """
+
+    allow_reuse_address = True
+    daemon_threads = True  # comment to keep threads alive until finished
+
+    clients = []
+    id_counter = 0
+
+    def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
+        logger.setLevel(loglevel)
+        self.port = port
+        TCPServer.__init__(self, (host, port), WebSocketHandler)
+
+    def _message_received_(self, handler, msg):
+        self.message_received(self.handler_to_client(handler), self, msg)
+
+    def _ping_received_(self, handler, msg):
+        handler.send_pong(msg)
+
+    def _pong_received_(self, handler, msg):
+        pass
+
+    def _new_client_(self, handler):
+        self.id_counter += 1
+        client = {
+            'id': self.id_counter,
+            'handler': handler,
+            'address': handler.client_address
+        }
+        self.clients.append(client)
+        self.new_client(client, self)
+
+    def _client_left_(self, handler):
+        client = self.handler_to_client(handler)
+        self.client_left(client, self)
+        if client in self.clients:
+            self.clients.remove(client)
+
+    def _unicast_(self, to_client, msg):
+        to_client['handler'].send_message(msg)
+
+    def _multicast_(self, msg):
+        for client in self.clients:
+            self._unicast_(client, msg)
+
+    def handler_to_client(self, handler):
+        for client in self.clients:
+            if client['handler'] == handler:
+                return client
+
+
+class WebSocketHandler(StreamRequestHandler):
+
+    def __init__(self, socket, addr, server):
+        self.server = server
+        StreamRequestHandler.__init__(self, socket, addr, server)
+
+    def setup(self):
+        StreamRequestHandler.setup(self)
+        self.keep_alive = True
+        self.handshake_done = False
+        self.valid_client = False
+
+    def handle(self):
+        while self.keep_alive:
+            if not self.handshake_done:
+                self.handshake()
+            elif self.valid_client:
+                self.read_next_message()
+
+    def read_bytes(self, num):
+        # python3 gives ordinal of byte directly
+        bytes = self.rfile.read(num)
+        if sys.version_info[0] < 3:
+            return map(ord, bytes)
+        else:
+            return bytes
+
+    def read_next_message(self):
+        try:
+            b1, b2 = self.read_bytes(2)
+        except ValueError as e:
+            b1, b2 = 0, 0
+
+        fin    = b1 & FIN
+        opcode = b1 & OPCODE
+        masked = b2 & MASKED
+        payload_length = b2 & PAYLOAD_LEN
+
+        if not b1:
+            logger.info("Client closed connection.")
+            self.keep_alive = 0
+            return
+        if opcode == OPCODE_CLOSE_CONN:
+            logger.info("Client asked to close connection.")
+            self.keep_alive = 0
+            return
+        if not masked:
+            logger.warn("Client must always be masked.")
+            self.keep_alive = 0
+            return
+        if opcode == OPCODE_CONTINUATION:
+            logger.warn("Continuation frames are not supported.")
+            return
+        elif opcode == OPCODE_BINARY:
+            logger.warn("Binary frames are not supported.")
+            return
+        elif opcode == OPCODE_TEXT:
+            opcode_handler = self.server._message_received_
+        elif opcode == OPCODE_PING:
+            opcode_handler = self.server._ping_received_
+        elif opcode == OPCODE_PONG:
+            opcode_handler = self.server._pong_received_
+        else:
+            logger.warn("Unknown opcode %#x." + opcode)
+            self.keep_alive = 0
+            return
+
+        if payload_length == 126:
+            payload_length = struct.unpack(">H", self.rfile.read(2))[0]
+        elif payload_length == 127:
+            payload_length = struct.unpack(">Q", self.rfile.read(8))[0]
+
+        masks = self.read_bytes(4)
+        decoded = ""
+        for char in self.read_bytes(payload_length):
+            char ^= masks[len(decoded) % 4]
+            decoded += chr(char)
+        opcode_handler(self, decoded)
+
+    def send_message(self, message):
+        self.send_text(message)
+
+    def send_pong(self, message):
+        self.send_text(message, OPCODE_PONG)
+
+    def send_text(self, message, opcode=OPCODE_TEXT):
+        """
+        Important: Fragmented(=continuation) messages are not supported since
+        their usage cases are limited - when we don't know the payload length.
+        """
+
+        # Validate message
+        if isinstance(message, bytes):
+            message = try_decode_UTF8(message)  # this is slower but ensures we have UTF-8
+            if not message:
+                logger.warning("Can\'t send message, message is not valid UTF-8")
+                return False
+        elif isinstance(message, str) or isinstance(message, unicode):
+            pass
+        else:
+            logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message))
+            return False
+
+        header  = bytearray()
+        payload = encode_to_UTF8(message)
+        payload_length = len(payload)
+
+        # Normal payload
+        if payload_length <= 125:
+            header.append(FIN | opcode)
+            header.append(payload_length)
+
+        # Extended payload
+        elif payload_length >= 126 and payload_length <= 65535:
+            header.append(FIN | opcode)
+            header.append(PAYLOAD_LEN_EXT16)
+            header.extend(struct.pack(">H", payload_length))
+
+        # Huge extended payload
+        elif payload_length < 18446744073709551616:
+            header.append(FIN | opcode)
+            header.append(PAYLOAD_LEN_EXT64)
+            header.extend(struct.pack(">Q", payload_length))
+
+        else:
+            raise Exception("Message is too big. Consider breaking it into chunks.")
+            return
+
+        self.request.send(header + payload)
+
+    def handshake(self):
+        message = self.request.recv(1024).decode().strip()
+        upgrade = re.search('\nupgrade[\s]*:[\s]*websocket', message.lower())
+        if not upgrade:
+            self.keep_alive = False
+            return
+        key = re.search('\n[sS]ec-[wW]eb[sS]ocket-[kK]ey[\s]*:[\s]*(.*)\r\n', message)
+        if key:
+            key = key.group(1)
+        else:
+            logger.warning("Client tried to connect but was missing a key")
+            self.keep_alive = False
+            return
+        response = self.make_handshake_response(key)
+        self.handshake_done = self.request.send(response.encode())
+        self.valid_client = True
+        self.server._new_client_(self)
+
+    def make_handshake_response(self, key):
+        return \
+          'HTTP/1.1 101 Switching Protocols\r\n'\
+          'Upgrade: websocket\r\n'              \
+          'Connection: Upgrade\r\n'             \
+          'Sec-WebSocket-Accept: %s\r\n'        \
+          '\r\n' % self.calculate_response_key(key)
+
+    def calculate_response_key(self, key):
+        GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+        hash = sha1(key.encode() + GUID.encode())
+        response_key = b64encode(hash.digest()).strip()
+        return response_key.decode('ASCII')
+
+    def finish(self):
+        self.server._client_left_(self)
+
+
+def encode_to_UTF8(data):
+    try:
+        return data.encode('UTF-8')
+    except UnicodeEncodeError as e:
+        logger.error("Could not encode data to UTF-8 -- %s" % e)
+        return False
+    except Exception as e:
+        raise(e)
+        return False
+
+
+def try_decode_UTF8(data):
+    try:
+        return data.decode('utf-8')
+    except UnicodeDecodeError:
+        return False
+    except Exception as e:
+        raise(e)

部分文件因文件數量過多而無法顯示