PyAlgoTradeの日本語解説ブログ

PyAlgoTrade の勝手に日本語解説ブログ。日本語の内容に関して保証はいたしておりません。必ず本家のサイトをご確認ください。

2015年9月1日火曜日

PyAlgoTrade のチュートリアルをかいつまんで訳す その3 最適化

オプティマイザを使いましょう。アイデアは非常に簡単です:


  • 担当する1つのサーバがあります:
  • 戦略を実行するためのBarを提供します。
  • 戦略を実行するためのパラメータを提供します。
  • ワーカーのそれぞれの戦略の結果を記録します。
  • 担当する複数のワーカーがいます。
  • サーバが提供するバーやパラメータを使用して戦略を実行します。


これらを説明するために、我々はRSI2として知られている戦略を使用します
http://stockcharts.com/school/doku.php?id=chart_school:trading_strategies:rsi2)


  • トレンドを識別するための単純移動平均期間。この市場参入を単純移動平均150と250の間の範囲で試します。
  • 出口点に対してより小さい単純移動平均期間。この脱出を単純移動平均5と15の間の範囲で試します。。
  • ショート/ロング・ポジションの両方を入力するためのRSI期間。このrsiPeriodを2と10の間の範囲で試します。
  • ロングポジションエントリのRSI売られ過ぎのしきい値。我々は、このoverSoldThresholdを5と25の範囲で試します。
  • ショートポジションエントリのRSI買わしきい値。我々は、このoverBoughtThresholdを75と95の間の範囲で試します。

私の数学がOKであれば、それらは4409559異なる組み合わせです。

パラメータの一組のためにこの戦略をテストすることは私には約0.16秒を要しました。
すべての組み合わせを実行するために、約8.5日かかります。
それは長い時間ですが、8コアのコンピュータを得ることができるならば、合計時間は約2.5時間まで減らすことができます。
手っ取り早く言うと、並列演算が必要です。



「ダウ・ジョーンズ工業株平均」の毎日のバーの3年をダウンロードして起動してみましょう:

python -c "from pyalgotrade.tools import yahoofinance; yahoofinance.download_daily_bars('dia', 2009, 'dia-2009.csv')"
python -c "from pyalgotrade.tools import yahoofinance; yahoofinance.download_daily_bars('dia', 2010, 'dia-2010.csv')"
python -c "from pyalgotrade.tools import yahoofinance; yahoofinance.download_daily_bars('dia', 2011, 'dia-2011.csv')"
rsi2.pyとして以下のコードを保存します。
from pyalgotrade import strategy
from pyalgotrade.technical import ma
from pyalgotrade.technical import rsi
from pyalgotrade.technical import cross


class RSI2(strategy.BacktestingStrategy):
    def __init__(self, feed, instrument, entrySMA, exitSMA, rsiPeriod, overBoughtThreshold, overSoldThreshold):
        strategy.BacktestingStrategy.__init__(self, feed)
        self.__instrument = instrument
        # We'll use adjusted close values, if available, instead of regular close values.
        if feed.barsHaveAdjClose():
            self.setUseAdjustedValues(True)
        self.__priceDS = feed[instrument].getPriceDataSeries()
        self.__entrySMA = ma.SMA(self.__priceDS, entrySMA)
        self.__exitSMA = ma.SMA(self.__priceDS, exitSMA)
        self.__rsi = rsi.RSI(self.__priceDS, rsiPeriod)
        self.__overBoughtThreshold = overBoughtThreshold
        self.__overSoldThreshold = overSoldThreshold
        self.__longPos = None
        self.__shortPos = None

    def getEntrySMA(self):
        return self.__entrySMA

    def getExitSMA(self):
        return self.__exitSMA

    def getRSI(self):
        return self.__rsi

    def onEnterCanceled(self, position):
        if self.__longPos == position:
            self.__longPos = None
        elif self.__shortPos == position:
            self.__shortPos = None
        else:
            assert(False)

    def onExitOk(self, position):
        if self.__longPos == position:
            self.__longPos = None
        elif self.__shortPos == position:
            self.__shortPos = None
        else:
            assert(False)

    def onExitCanceled(self, position):
        # If the exit was canceled, re-submit it.
        position.exitMarket()

    def onBars(self, bars):
        # Wait for enough bars to be available to calculate SMA and RSI.
        if self.__exitSMA[-1] is None or self.__entrySMA[-1] is None or self.__rsi[-1] is None:
            return

        bar = bars[self.__instrument]
        if self.__longPos is not None:
            if self.exitLongSignal():
                self.__longPos.exitMarket()
        elif self.__shortPos is not None:
            if self.exitShortSignal():
                self.__shortPos.exitMarket()
        else:
            if self.enterLongSignal(bar):
                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())
                self.__longPos = self.enterLong(self.__instrument, shares, True)
            elif self.enterShortSignal(bar):
                shares = int(self.getBroker().getCash() * 0.9 / bars[self.__instrument].getPrice())
                self.__shortPos = self.enterShort(self.__instrument, shares, True)

    def enterLongSignal(self, bar):
        return bar.getPrice() > self.__entrySMA[-1] and self.__rsi[-1] <= self.__overSoldThreshold

    def exitLongSignal(self):
        return cross.cross_above(self.__priceDS, self.__exitSMA) and not self.__longPos.exitActive()

    def enterShortSignal(self, bar):
        return bar.getPrice() < self.__entrySMA[-1] and self.__rsi[-1] >= self.__overBoughtThreshold

    def exitShortSignal(self):
        return cross.cross_below(self.__priceDS, self.__exitSMA) and not self.__shortPos.exitActive()

サーバーのスクリプトは以下になります。
import itertools
from pyalgotrade.barfeed import yahoofeed
from pyalgotrade.optimizer import server


def parameters_generator():
    instrument = ["dia"]
    entrySMA = range(150, 251)
    exitSMA = range(5, 16)
    rsiPeriod = range(2, 11)
    overBoughtThreshold = range(75, 96)
    overSoldThreshold = range(5, 26)
    return itertools.product(instrument, entrySMA, exitSMA, rsiPeriod, overBoughtThreshold, overSoldThreshold)

# The if __name__ == '__main__' part is necessary if running on Windows.
if __name__ == '__main__':
    # Load the feed from the CSV files.
    feed = yahoofeed.Feed()
    feed.addBarsFromCSV("dia", "dia-2009.csv")
    feed.addBarsFromCSV("dia", "dia-2010.csv")
    feed.addBarsFromCSV("dia", "dia-2011.csv")

    # Run the server.
    server.serve(feed, parameters_generator(), "localhost", 5000)

サーバーのコードは3つのことを行っています。
組み合わせストラテジーの異なるパラメータの生成
CSVファイルからフィードをダウンロード
サーバーを実行してローカル接続port5000で待ち受け

ワーカースクリプトを実行して待ち受けているサーバーを呼び出します。
from pyalgotrade.optimizer import worker
import rsi2

# The if __name__ == '__main__' part is necessary if running on Windows.
if __name__ == '__main__':
    worker.run(rsi2.RSI2, "localhost", 5000, workerName="localworker")

コンソールを2つ開いて、サーバーのスクリプトを実行した後、別のコンソールでワーカーを実行する それぞれの出力は以下のようなになる。 サーバー
2014-05-03 15:04:01,083 server [INFO] Loading bars
2014-05-03 15:04:01,348 server [INFO] Waiting for workers
2014-05-03 15:04:58,277 server [INFO] Partial result 1242173.28754 with parameters: ('dia', 150, 5, 2, 91, 19) from localworker
2014-05-03 15:04:58,566 server [INFO] Partial result 1203266.33502 with parameters: ('dia', 150, 5, 2, 81, 19) from localworker
2014-05-03 15:05:50,965 server [INFO] Partial result 1220763.1579 with parameters: ('dia', 150, 5, 3, 83, 24) from localworker
2014-05-03 15:05:51,325 server [INFO] Partial result 1221627.50793 with parameters: ('dia', 150, 5, 3, 80, 24) from localworker
.
.
ワーカー
2014-05-03 15:02:25,377 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 94, 5)
2014-05-03 15:02:25,661 localworker [INFO] Result 1090481.06342
2014-05-03 15:02:25,661 localworker [INFO] Result 1031470.23717
2014-05-03 15:02:25,662 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 93, 25)
2014-05-03 15:02:25,665 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 84, 14)
2014-05-03 15:02:25,995 localworker [INFO] Result 1135558.55667
2014-05-03 15:02:25,996 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 93, 24)
2014-05-03 15:02:26,006 localworker [INFO] Result 1083987.18174
2014-05-03 15:02:26,007 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 84, 13)
2014-05-03 15:02:26,256 localworker [INFO] Result 1093736.17175
2014-05-03 15:02:26,257 localworker [INFO] Running strategy with parameters ('dia', 150, 5, 2, 84, 12)
2014-05-03 15:02:26,280 localworker [INFO] Result 1135558.5566


一つのサーバーで複数のワーカーを走らせたほうがいい。というか1ワーカーだとものすごく待つ。

サーバーではなくデスクトップ機で並列演算をさせたいのであれば、 pyalgotrade.optimizer.local モジュールを使いましょう。


import itertools
from pyalgotrade.optimizer import local
from pyalgotrade.barfeed import yahoofeed
import rsi2


def parameters_generator():
    instrument = ["dia"]
    entrySMA = range(150, 251)
    exitSMA = range(5, 16)
    rsiPeriod = range(2, 11)
    overBoughtThreshold = range(75, 96)
    overSoldThreshold = range(5, 26)
    return itertools.product(instrument, entrySMA, exitSMA, rsiPeriod, overBoughtThreshold, overSoldThreshold)


# The if __name__ == '__main__' part is necessary if running on Windows.
if __name__ == '__main__':
    # Load the feed from the CSV files.
    feed = yahoofeed.Feed()
    feed.addBarsFromCSV("dia", "dia-2009.csv")
    feed.addBarsFromCSV("dia", "dia-2010.csv")
    feed.addBarsFromCSV("dia", "dia-2011.csv")

    local.run(rsi2.RSI2, feed, parameters_generator())
このコードが実行しているのは3つ

違うパラメーターのジェネレーター関数の宣言
CSVファイルの読み込み
pyalgotrade.optimizer.local モジュールを使って並列実行し最適な結果を求める

ちなみに、ベストな組み合わせを使った時の結果は$1000が$2314.40になった。

  • entrySMA: 154
  • exitSMA: 5
  • rsiPeriod: 2
  • overBoughtThreshold: 91
  • overSoldThreshold: 18


0 件のコメント:

コメントを投稿