Torch model

Example of using a Torch RNN model along with Vectorizer to fit a set of string expressions.

import numpy as np
from itertools import product

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam

from grammaropt.grammar import build_grammar
from grammaropt.grammar import Vectorizer
from grammaropt.grammar import as_str
from grammaropt.grammar import NULL_SYMBOL
from grammaropt.rnn import RnnAdapter
from grammaropt.rnn import RnnModel
from grammaropt.rnn import RnnWalker

def acc(pred, true_classes):
    """Return the mean accuracy of ``pred`` against ``true_classes``.

    The predicted class of each row is the index of its largest entry along
    dimension 1. The result is the mean of the element-wise matches with
    ``true_classes``, returned as a float tensor.
    """
    # ``max`` along a dimension yields (values, indices); only indices matter.
    predicted = pred.max(1)[1]
    hits = predicted.eq(true_classes)
    return hits.float().mean()


# Grammar and corpus
# The grammar describes arithmetic expressions built from "x", single digits,
# the operators + * /, parentheses and the functions sin, cos and exp.
rules = r"""
    S = (T "+" S) / (T "*" S) / (T "/" S) / T
    T = (po S pc) / ("sin" po S pc) / ("cos" po S pc) / ("exp" po S pc) / "x" / int
    po = "("
    pc = ")"
    int = "0" / "1" / "2" / "3" / "4" / "5" / "6" / "7" / "8" / "9"
"""
grammar = build_grammar(rules)

# One string "x*i+j" per ordered pair of digits (i, j): 100 expressions.
corpus = ['x*{}+{}'.format(*digits) for digits in product(range(10), repeat=2)]

vect = Vectorizer(grammar, pad=True)
X = vect.transform(corpus)
# Prefix every encoded sequence with id 0, then stack into one int32 array.
# NOTE(review): 0 is presumably the id of the begin/NULL symbol - confirm
# against Vectorizer.tok_to_id.
X = np.array([[0] + seq for seq in X]).astype('int32')

# Model
# Sizes derived from the encoded corpus.
vocab_size = len(vect.tok_to_id)
max_length = max(len(row) for row in X)

# Hyper-parameters.
emb_size = 32
hidden_size = 32
batch_size = 32
epochs = 1000

model = RnnModel(
    vocab_size=vocab_size,
    emb_size=emb_size,
    hidden_size=hidden_size,
)
optim = Adam(model.parameters(), lr=1e-3)

# The adapter and walker are used during training to generate an expression
# from the grammar with the current model (see the ``wl.walk()`` call).
adp = RnnAdapter(model, tok_to_id=vect.tok_to_id, begin_tok=NULL_SYMBOL)
wl = RnnWalker(grammar, adp, temperature=1.0, min_depth=1, max_depth=5)

# Training
# Next-token prediction: the input is every sequence without its last token
# and the target is the same sequence shifted left by one position.
I = X[:, 0:-1]
O = X[:, 1:]
crit = nn.CrossEntropyLoss()
# Exponential moving averages (decay 0.9) of the loss and accuracy, used
# only for logging.
avg_loss = 0.
avg_precision = 0.
for i in range(epochs):
    for j in range(0, len(I), batch_size):
        # Tensors can be fed to the model directly; wrapping them in
        # ``Variable`` has been a no-op since PyTorch 0.4.
        inp = torch.from_numpy(I[j:j+batch_size]).long()
        # Targets are flattened to 1-D class indices for CrossEntropyLoss.
        # NOTE(review): this assumes ``model(inp)`` returns one row of scores
        # per flattened target position - confirm against RnnModel.
        out = torch.from_numpy(O[j:j+batch_size].flatten()).long()

        model.zero_grad()
        y = model(inp)
        loss = crit(y, out)
        precision = acc(y, out)
        loss.backward()
        optim.step()

        # ``.item()`` extracts the Python scalar; the former ``.data[0]``
        # indexing raises IndexError on 0-dim tensors in PyTorch >= 0.5.
        avg_loss = avg_loss * 0.9 + loss.item() * 0.1
        avg_precision = avg_precision * 0.9 + precision.item() * 0.1
        # Every 10th epoch, report after each batch and show one expression
        # generated with the current model.
        if i % 10 == 0:
            print('Epoch : {:05d} Avg loss : {:.6f} Avg Precision : {:.6f}'.format(i, avg_loss, avg_precision))
            print('Generated :')
            wl.walk()
            expr = as_str(wl.terminals)
            print(expr)

Total running time of the script: ( 0 minutes 0.000 seconds)

Gallery generated by Sphinx-Gallery