Example #1
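# Chains two custom MemoryUpdate ops, runs forward/backward on small dummy tensors
# (only memory slot 1 has its flag set, update factor 0.8), and prints the input
# memory and the resulting output.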
def memory_update_test():
    memory = mx.symbol.Variable('memory')
    update = mx.symbol.Variable('update')
    flag = mx.symbol.Variable('flag')
    update_factor = mx.symbol.Variable('update_factor')
    output = mx.symbol.MemoryUpdate(data=memory, update=update, flag=flag, factor=update_factor)
    output2 = mx.symbol.MemoryUpdate(data=output, update=update, flag=flag, factor=update_factor)
    output2 = mx.symbol.BlockGrad(data=output2)
    data_shapes = {'memory': (5, 3, 2, 2), 'update': (1, 3, 2, 2), 'flag': (5, ),
                   'update_factor':(1,)}
    net = Base(sym=output2, data_shapes=data_shapes)

    memory_npy = numpy.zeros((5, 3, 2, 2), dtype=numpy.float32)
    update_npy = numpy.zeros((1, 3, 2, 2), dtype=numpy.float32)
    flag_npy = numpy.zeros((5,), dtype=numpy.float32)
    update_factor_npy = numpy.array([0.8,])
    for i in range(5):
        memory_npy[i, :, :, :] = 2*i + 1
    flag_npy[1] = 1
    output_npy = net.forward(data_shapes=data_shapes, memory=memory_npy, update=update_npy, flag=flag_npy,
                update_factor=update_factor_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, memory=memory_npy, update=update_npy, flag=flag_npy,
                update_factor=update_factor_npy)
    print memory_npy
    print output_npy
Example #2
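# Benchmarks a 2D FFT round trip: MXNet's FFT2D/IFFT2D on the GPU versus pyfftw's
# rfft2/irfft2 on the CPU, printing the squared reconstruction difference per
# iteration and the accumulated CPU and GPU times.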
def speed_compare_cufft_fftw():
    a = mx.symbol.Variable('a')
    b = mx.symbol.Variable('b')
    a_fft = mx.symbol.FFT2D(data=a, batchsize=32)
    a_recons = mx.symbol.IFFT2D(data=a_fft, output_shape=(64, 64), batchsize=32)
    c = mx.symbol.Flatten(a_fft)
    d_pred = mx.symbol.FullyConnected(data=c, num_hidden=64)
    d = mx.symbol.LinearRegressionOutput(data=d_pred, label=b)
    data_shapes = {'a': (10, 96*3, 64, 64)}
    a_npy = numpy.zeros((1, 1, 3, 3))
    a_npy[0,0,0,:] = numpy.array([1,2,1])
    a_npy[0,0,1,:] = numpy.array([2,3,2])
    a_npy[0,0,2,:] = numpy.array([2,3,4])
    optimizer = mx.optimizer.create(name='sgd', learning_rate=0.01,
                                    clip_gradient=None,
                                    rescale_grad=1.0, wd=0.)
    updater = mx.optimizer.get_updater(optimizer)
    net = Base(sym=a_recons, data_shapes=data_shapes, name='net', ctx=mx.gpu())
    cpu_time = 0
    gpu_time = 0
    data = numpy.zeros((10, 96*3, 64, 64))
    for i in range(100):
        data[:] = numpy.random.rand(10, 96*3, 64, 64)
        start = time.time()
        output_fftw = pyfftw.interfaces.numpy_fft.irfft2(pyfftw.interfaces.numpy_fft.rfft2(data))
        end = time.time()
        cpu_time += end-start
        start = time.time()
        output_mxnet = net.forward(is_train=False, data_shapes=data_shapes, a=nd.array(data, ctx=mx.gpu()))[0].asnumpy()
        end = time.time()
        gpu_time += end-start
        print numpy.square(output_mxnet - output_fftw.real).sum()
    print cpu_time
    print gpu_time
Example #3
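# Drives the custom MemoryStatUpdateOp in 'write' and 'read' modes alternately,
# feeding the updated counter and visiting timestamp back in on every iteration
# and printing the resulting statistics (raw_input pauses after each step).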
def memory_stat_read_write():
    counter = mx.symbol.Variable('counter')
    visiting_timestamp = mx.symbol.Variable('visiting_timestamp')
    control_flag = mx.symbol.Variable('control_flag')
    memory_write_control_op = MemoryStatUpdateOp(mode='write')
    memory_read_control_op = MemoryStatUpdateOp(mode='read')
    controlling_stats_afterwrite = memory_write_control_op(counter=counter, visiting_timestamp=visiting_timestamp,
                                                        control_flag=control_flag)
    controlling_stats_afterread = memory_read_control_op(counter=counter, visiting_timestamp=visiting_timestamp,
                                                control_flag=control_flag)
    data_shapes = {'counter': (1, 4), 'visiting_timestamp': (1, 4), 'control_flag':(1,)}

    write_net = Base(sym=controlling_stats_afterwrite, data_shapes=data_shapes, name='write_net')
    read_net = Base(sym=controlling_stats_afterread, data_shapes=data_shapes, name='read_net')

    current_counter = numpy.array([[10, 20, 3, 40]])
    current_visiting_timestamp = numpy.array([[1, 3, 2, 4]])

    for i in range(100):
        write_outputs = write_net.forward(data_shapes=data_shapes, counter=current_counter,
                              visiting_timestamp=current_visiting_timestamp, control_flag=numpy.array([i%3,]))
        write_net.backward(data_shapes=data_shapes)
        current_counter = write_outputs[0].asnumpy()
        current_visiting_timestamp = write_outputs[1].asnumpy()
        print 'Control Flag:', i%3, 'Counter:', current_counter, \
            " Visiting Timestamp:", current_visiting_timestamp, "Flags:", write_outputs[2].asnumpy()

        read_outputs = read_net.forward(data_shapes=data_shapes, counter=current_counter,
                              visiting_timestamp=current_visiting_timestamp, control_flag=numpy.array([i%4,]))
        read_net.backward(data_shapes=data_shapes)
        current_counter = read_outputs[0].asnumpy()
        current_visiting_timestamp = read_outputs[1].asnumpy()
        print 'Control Flag:', i%4, 'Counter:', current_counter, \
            " Visiting Timestamp:", current_visiting_timestamp
        ch = raw_input()
Example #4
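# Minimal check of the custom Conjugate op: forward/backward on an all-ones
# tensor, printing the output and its shape.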
def complex_conjugate():
    data = mx.symbol.Variable('data')
    conjugate = mx.symbol.Conjugate(data=data)
    conjugate = mx.symbol.BlockGrad(data=conjugate)
    data_shapes = {'data': (1, 1, 4, 4)}
    net = Base(sym=conjugate, data_shapes=data_shapes)
    data_npy = numpy.ones((1, 1, 4, 4))
    output_npy = net.forward(data_shapes=data_shapes, data=data_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, data=data_npy)
    print output_npy
    print output_npy.shape
Example #5
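# Checks the custom BroadcastChannel op by broadcasting a random (1, 1, 4, 4)
# tensor to size 10 along dimension 0 and printing the output and its shape.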
def broadcast_channel():
    data = mx.symbol.Variable('data')
    broadcast = mx.symbol.BroadcastChannel(data=data, dim=0, size=10)
    broadcast = mx.symbol.BlockGrad(data=broadcast)
    data_shapes = {'data': (1, 1, 4, 4)}
    net = Base(sym=broadcast, data_shapes=data_shapes)
    data_npy = numpy.random.rand(1, 1, 4, 4)
    output_npy = net.forward(data_shapes=data_shapes, data=data_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, data=data_npy)
    print output_npy
    print output_npy.shape
Example #6
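# Checks the custom SumChannel op by summing an all-ones (10, 9, 8, 7) tensor
# along dimension 3 and printing the output and its shape.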
def sum_channel_test():
    data = mx.symbol.Variable('data')
    summed_data = mx.symbol.SumChannel(data=data, dim=3)
    summed_data = mx.symbol.BlockGrad(data=summed_data)
    data_shapes = {'data': (10, 9, 8, 7)}
    net = Base(sym=summed_data, data_shapes=data_shapes)
    data_npy = numpy.ones((10,9,8,7))
    output_npy = net.forward(data_shapes=data_shapes, data=data_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, data=data_npy)
    print output_npy
    print output_npy.shape
Example #7
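# Checks the custom ComplexHadamard op on two small hand-modified all-ones inputs
# and prints the product and its shape.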
def complex_hadamard_test():
    ldata = mx.symbol.Variable('ldata')
    rdata = mx.symbol.Variable('rdata')
    product = mx.symbol.ComplexHadamard(ldata=ldata, rdata=rdata)
    product = mx.symbol.BlockGrad(data=product)
    data_shapes = {'ldata': (1, 1, 4, 4), 'rdata': (1, 1, 4, 4)}
    net = Base(sym=product, data_shapes=data_shapes)
    ldata_npy = numpy.ones((1, 1, 4, 4))
    rdata_npy = numpy.ones((1, 1, 4, 4))
    ldata_npy[0,0,0,0] = 2
    rdata_npy[0,0,1,0] = -1
    output_npy = net.forward(data_shapes=data_shapes, ldata=ldata_npy, rdata=rdata_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, ldata=ldata_npy, rdata=rdata_npy)
    print output_npy
    print output_npy.shape
Example #8
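# Checks the custom MemoryChoose op: fills five memory slots with distinct values,
# selects slot 3 through the index input, and prints the chosen unit and its shape.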
def memory_choose_test():
    memory = mx.symbol.Variable('memory')
    index = mx.symbol.Variable('index')
    chosen_unit = mx.symbol.MemoryChoose(data=memory, index=index)
    chosen_unit = mx.symbol.BlockGrad(data=chosen_unit)
    data_shapes ={'memory': (5, 4, 3, 3), 'index': (1,)}
    net = Base(sym=chosen_unit, data_shapes=data_shapes)
    memory_npy = numpy.zeros((5, 4, 3, 3), dtype=numpy.float32)
    for i in range(5):
        memory_npy[i, :, :, :] = i
    index_npy = numpy.array([3], dtype=numpy.float32)
    print net.internal_sym_names

    output = net.forward(data_shapes=data_shapes, memory=memory_npy, index=index_npy)[0].asnumpy()
    net.backward(data_shapes=data_shapes, memory=memory_npy, index=index_npy)
    print output
    print output.shape
Example #9
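# Verifies mx.symbol.conj: even-indexed entries (real parts) should pass through
# unchanged and odd-indexed entries (imaginary parts) should be negated, both in
# the forward output and in the gradient with respect to 'a'.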
def test_mxnet_conj():
    a = mx.symbol.Variable('a')
    b = mx.symbol.conj(a)
    base_shape = (2, 10)
    data_shapes = {'a': base_shape}
    a_npy = numpy.random.rand(*base_shape)
    out_grad_npy = numpy.random.rand(*base_shape)
    net = Base(sym=b, data_shapes=data_shapes)
    outputs = net.forward(is_train=True, a=a_npy)
    print 'conj:'
    print numpy.square(outputs[0].asnumpy()[:, ::2] - a_npy[:, ::2]).sum()
    print numpy.square(outputs[0].asnumpy()[:, 1::2] + a_npy[:, 1::2]).sum()
    net.backward(out_grads=[nd.array(out_grad_npy, ctx=mx.gpu())])
    print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, ::2] -
                       out_grad_npy[:, ::2]).sum()
    print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, 1::2] +
                       out_grad_npy[:, 1::2]).sum()
Example #10
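    # Excerpt of a nested helper: repeatedly samples from the custom LogMoGPolicy
    # op and plots a 2D histogram of the drawn samples; batch_size, num_centers,
    # sample_dim, total_sample_num and the *_npy arrays come from the enclosing scope.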
    def speed_mxnet_test():
        prob = mx.symbol.Variable('prob')
        mean = mx.symbol.Variable('mean')
        var = mx.symbol.Variable('var')
        score = mx.symbol.Variable('score')
        out = mx.symbol.Custom(prob=prob, mean=mean, var=var, score=score, name='policy', op_type='LogMoGPolicy')
        data_shapes = {'prob': (batch_size, num_centers), 'mean': (batch_size, num_centers, sample_dim),
                       'var': (batch_size, num_centers, sample_dim), 'score': (batch_size,)}

        net = Base(sym=out, data_shapes=data_shapes, ctx=mx.cpu())
        sample_npy = numpy.empty((total_sample_num, mean_npy.shape[0], mean_npy.shape[2]), dtype=numpy.float32)
        for i in range(total_sample_num):
            if 0 == i:
                sample_npy[i, :, :] = net.forward(is_train=False, prob=prob_npy, mean=mean_npy,
                                                  var=var_npy, score=score_npy)[0].asnumpy()
            else:
                sample_npy[i, :, :] = net.forward(is_train=False)[0].asnumpy()
        plt.hist2d(sample_npy[:, 1, 0], sample_npy[:, 1, 1], (200, 200), cmap=plt.cm.jet)
        plt.colorbar()
        plt.show()
Example #11
File: test_policy.py Project: flyers/Arena
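# Trains a small two-layer policy network with the custom LogSoftmaxPolicy op via
# policy gradients on a toy task (simple_game_discrete), maintaining a running
# baseline and live-plotting the per-iteration score.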
def test_logsoftmax():
    var = mx.symbol.Variable('var')
    data = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    net = mx.symbol.Activation(data=net, name='relu1', act_type='relu')
    net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=4)
    net = mx.symbol.Custom(data=net, name='policy', op_type='LogSoftmaxPolicy')
    ctx = mx.gpu()
    minibatch_size = 100
    data_shapes = {
        'data': (minibatch_size, 4),
        'policy_score': (minibatch_size, )
    }
    qnet = Base(data_shapes=data_shapes,
                sym_gen=net,
                name='PolicyNet',
                initializer=mx.initializer.Xavier(factor_type="in",
                                                  magnitude=1.0),
                ctx=ctx)
    print qnet.internal_sym_names

    lr = 0.00001
    optimizer = mx.optimizer.create(name='sgd',
                                    learning_rate=0.00001,
                                    clip_gradient=None,
                                    rescale_grad=1.0,
                                    wd=0.)
    updater = mx.optimizer.get_updater(optimizer)
    total_iter = 1000000
    stats = numpy.zeros((total_iter, 3), dtype=numpy.float32)
    plt.ion()
    fig, ax = plt.subplots()
    lines, = ax.plot([], [])
    ax.set_autoscaley_on(True)
    baseline = 0
    for i in range(total_iter):
        data = numpy.random.randn(minibatch_size, 4)
        outputs = qnet.forward(is_train=True, data=data)
        action = outputs[0].asnumpy()
        prob = outputs[1].asnumpy()
        #print 'data=', data, 'action=', action, 'prob=', prob
        #ch = raw_input()
        score = simple_game_discrete(data, action)
        baseline = baseline - 0.001 * (baseline - score.mean())
        print 'score=', score.mean(), 'acc=', numpy.sum(
            action == numpy.argmax(data *
                                   data, axis=1)).mean(), 'baseline=', baseline
        stats[i] = [
            score.mean(),
            numpy.sum(action == numpy.argmax(data * data, axis=1)).mean(),
            baseline
        ]
        qnet.backward(policy_score=score - baseline)
        qnet.update(updater)
        update_line(lines, fig, ax, i,
                    score.mean())  # numpy.square(means - data*data).mean())
Example #12
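# Gradient check for the custom LogMoGPolicy op: its backward gradients for
# prob/mean/var are compared against finite-difference gradients of the log
# mixture-of-Gaussians likelihood (logmog), printing the squared differences.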
def mog_backward_test(batch_size=5, num_centers=11, sample_dim=33):
    prob = mx.symbol.Variable('prob')
    mean = mx.symbol.Variable('mean')
    var = mx.symbol.Variable('var')
    score = mx.symbol.Variable('score')
    out = mx.symbol.Custom(prob=prob, mean=mean, var=var, score=score, name='policy', op_type='LogMoGPolicy', implicit_backward=False)
    data_shapes = {'prob': (batch_size, num_centers), 'mean': (batch_size, num_centers, sample_dim),
                   'var': (batch_size, num_centers, sample_dim), 'score': (batch_size,),
                   'policy_backward_action': (batch_size, sample_dim)}
    net = Base(sym=out, data_shapes=data_shapes, ctx=mx.cpu())

    prob_npy = get_numpy_rng().rand(batch_size, num_centers)
    mean_npy = get_numpy_rng().rand(batch_size, num_centers, sample_dim) * 1 + 5
    var_npy = get_numpy_rng().rand(batch_size, num_centers, sample_dim) * 2 + 0.001
    prob_npy = prob_npy / prob_npy.sum(axis=1).reshape(prob_npy.shape[0], 1)
    score_npy = get_numpy_rng().rand(batch_size, )
    sample_npy = get_numpy_rng().rand(batch_size, sample_dim) * 1  + 5
    net.forward(is_train=True, prob=prob_npy, mean=mean_npy, var=var_npy)
    net.backward(score=score_npy, policy_backward_action=sample_npy)
    def fd_grad():
        eps = 1E-8
        base_loglikelihood = logmog(prob=prob_npy, mean=mean_npy, var=var_npy, score=score_npy, sample=sample_npy)
        fd_prob_grad = numpy.empty(prob_npy.size, dtype=numpy.float32)
        fd_mean_grad = numpy.empty(mean_npy.size, dtype=numpy.float32)
        fd_var_grad = numpy.empty(var_npy.size, dtype=numpy.float32)
        prob_delta = numpy.zeros(prob_npy.size, dtype=numpy.float32)
        mean_delta = numpy.zeros(mean_npy.size, dtype=numpy.float32)
        var_delta = numpy.zeros(var_npy.size, dtype=numpy.float32)
        for i in range(prob_npy.size):
            prob_delta[i] = eps
            fd_prob_grad[i] = (logmog(prob=prob_npy + prob_delta.reshape(prob_npy.shape), mean=mean_npy, var=var_npy, score=score_npy,
                                     sample=sample_npy) - base_loglikelihood)/eps
            prob_delta[i] = 0
        for i in range(mean_npy.size):
            mean_delta[i] = eps
            fd_mean_grad[i] = (logmog(prob=prob_npy, mean=mean_npy + mean_delta.reshape(mean_npy.shape), var=var_npy,
                                      score=score_npy,
                                      sample=sample_npy) - base_loglikelihood) / eps
            mean_delta[i] = 0
        for i in range(var_npy.size):
            var_delta[i] = eps
            fd_var_grad[i] = (logmog(prob=prob_npy, mean=mean_npy, var=var_npy + var_delta.reshape(var_npy.shape),
                                     score=score_npy,
                                     sample=sample_npy) - base_loglikelihood) / eps
            var_delta[i] = 0
        fd_prob_grad = fd_prob_grad.reshape(prob_npy.shape)
        fd_mean_grad = fd_mean_grad.reshape(mean_npy.shape)
        fd_var_grad = fd_var_grad.reshape(var_npy.shape)
        return fd_prob_grad, fd_mean_grad, fd_var_grad
    fd_prob_grad, fd_mean_grad, fd_var_grad = fd_grad()
    # print 'fd_prob_grad:', fd_prob_grad
    # print 'fd_mean_grad:', fd_mean_grad
    # print 'fd_var_grad:', fd_var_grad
    op_prob_grad = net.executor_pool.inputs_grad_dict.values()[0]['prob'].asnumpy()
    op_mean_grad = net.executor_pool.inputs_grad_dict.values()[0]['mean'].asnumpy()
    op_var_grad = net.executor_pool.inputs_grad_dict.values()[0]['var'].asnumpy()
    # print 'op_prob_grad:', op_prob_grad
    # print 'op_mean_grad:', op_mean_grad
    # print 'op_var_grad:', op_var_grad
    print 'prob_grad_diff:', numpy.square(op_prob_grad - fd_prob_grad).sum()
    print 'mean_grad_diff:', numpy.square(op_mean_grad - fd_mean_grad).sum()
    print 'var_grad_diff:', numpy.square(op_var_grad - fd_var_grad).sum()
Example #13
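# Fragment of a glimpse test: loads image and ground-truth ROI arrays (the
# statement that builds data_arr is truncated above), runs a timed forward pass
# of the 'GlimpseTest' network, and displays the resulting glimpse crops with OpenCV.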
                               height=rows,
                               width=cols),
                    ctx=ctx)

roi_arr = nd.array(load_roi(path=path + "\\groundtruth.txt",
                            num=batch_size,
                            height=rows,
                            width=cols),
                   ctx=ctx)

print data_arr.shape
print roi_arr.shape
data_shapes = {'data': (batch_size, 3, rows, cols), 'roi': (batch_size, 4)}

glimpse_test_net = Base(data_shapes=data_shapes,
                        sym=net,
                        name='GlimpseTest',
                        ctx=ctx)

start = time.time()
out_imgs = glimpse_test_net.forward(batch_size=batch_size,
                                    data=data_arr,
                                    roi=roi_arr)[0].asnumpy()
end = time.time()
print 'Time:', end - start
print out_imgs.shape
for i in range(batch_size):
    for j in range(depth):
        r, g, b = cv2.split(
            numpy.rollaxis(out_imgs[i, j * 3:(j + 1) * 3], 0, 3))
        reshaped_img = cv2.merge([b, g, r])
        cv2.imshow('image', reshaped_img / 255.0)
Example #14
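# Fragment of a tracking script: groups the prediction and memory symbols into a
# single network, copies the perception parameters in, collects constant inputs
# (initial write flag, update factor, initial memory and attention LSTM states),
# and samples an image/ROI sequence from the tracking iterator.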
#data_shapes.pop('MemoryHandler:memory_init:lstm0_c', None)
#init_memory_data.pop('MemoryHandler:memory_init:lstm0_c', None)
#data_shapes.pop('MemoryHandler:memory_init:lstm0_h', None)
#data_shapes.pop('MemoryHandler:memory_init:lstm1_c', None)
#init_memory_data.pop('MemoryHandler:memory_init:lstm1_c', None)
#data_shapes.pop('MemoryHandler:memory_init:lstm1_h', None)

print 'data_shapes:', data_shapes
print memory_to_sym_dict(memory)

# net = Base(sym=mx.symbol.Group(symbols=block_all(memory_to_sym_dict(memory).values())),
#            data_shapes=data_shapes)
print sym_out.keys()
net = Base(sym=mx.symbol.Group(sym_out.values() + [
    pred_center, pred_size, memory.numerators, memory.denominators,
    memory.status.counter, memory.status.visiting_timestamp
]),
           data_shapes=data_shapes)

net.print_stat()
perception_handler.set_params(net.params)

constant_inputs = OrderedDict()
constant_inputs['init_write_control_flag'] = 2
constant_inputs['update_factor'] = 0.2
constant_inputs.update(init_memory_data)
constant_inputs.update(init_attention_lstm_data)

seq_images, seq_rois = tracking_iterator.sample(length=sample_length)
additional_inputs = OrderedDict()
additional_inputs["data_images"] = seq_images
Example #15
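# Fragment that computes per-channel score maps as
# IFFT2D(ComplexHadamard(numerator / denominator, FFT2D(second_feature))), feeds
# in features and a Gaussian map derived from precomputed embedding data, and
# shows each channel's score map with OpenCV.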
denominator = mx.symbol.SumChannel(denominator)
denominator = mx.symbol.BroadcastChannel(data=denominator + regularizer,
                                         dim=1,
                                         size=channel_size)
scores = mx.symbol.ComplexHadamard(
    numerator / denominator, mx.symbol.FFT2D(second_feature, batchsize=64))
scores = mx.symbol.IFFT2D(data=scores,
                          output_shape=(numpy.int32(rows), numpy.int32(cols)),
                          batchsize=64)

data_shapes = {
    'feature': (1, channel_size, rows, cols),
    'second_feature': (1, channel_size, rows, cols),
    'gaussian_map': (1, 1, rows, cols)
}
net = Base(data_shapes=data_shapes, sym=scores)
outputs = net.forward(
    data_shapes=data_shapes,
    feature=numpy.rollaxis(embedding_data['xl'], 2).reshape(
        (1, channel_size, rows, cols)),
    second_feature=numpy.rollaxis(joint_embedding_data['xt'], 2).reshape(
        (1, channel_size, rows, cols)),
    gaussian_map=numpy.real(numpy.fft.ifft2(embedding_data['yf'])).reshape(
        (1, 1, rows, cols)))
for output in outputs:
    print output.shape
    for i in range(channel_size):
        score = output.asnumpy()[0, i, :, :]
        cv2.imshow('image', score / score.max())
        cv2.waitKey()
        score = numpy.rollaxis(joint_embedding_data['score_map'], 2)[i]
Example #16
File: copy_task.py Project: flyers/Arena
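# Fragment of the NTM copy task: builds a bucketed network whose initial memory,
# read content, read/write focuses and controller LSTM states are learned
# (learn_init_keys), then prints the parameter statistics.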
vis = PLTVisualizer()

max_input_seq_len = max_length * 2 + 2
max_output_seq_len = max_length
sym = sym_gen(max_length)
net = Base(data_shapes={
    'data': (max_input_seq_len, batch_size, data_dim),
    'target': (max_output_seq_len, batch_size, data_dim),
    'init_memory': (memory_size, memory_state_dim),
    'init_read_content': (num_reads, memory_state_dim),
    'NTM->read_head:init_focus': (num_reads, memory_size),
    'NTM->write_head:init_focus': (num_writes, memory_size),
    'controller->layer0:init_h': (control_state_dim, ),
    'controller->layer0:init_c': (control_state_dim, )
},
           sym_gen=sym_gen,
           learn_init_keys=[
               'init_memory', 'init_read_content', 'NTM->read_head:init_focus',
               'NTM->write_head:init_focus', 'controller->layer0:init_h',
               'controller->layer0:init_c'
           ],
           default_bucket_kwargs={'seqlen': max_length},
           initializer=NTMInitializer(factor_type="in",
                                      rnd_type="gaussian",
                                      magnitude=2),
           ctx=mx.gpu())
net.print_stat()
###init_memory_npy = numpy.tanh(numpy.random.normal(size=(batch_size, memory_size, memory_state_dim)))
# init_memory_npy = numpy.zeros((batch_size, memory_size, memory_state_dim), dtype=numpy.float32) + 0.1
# init_read_focus_npy = numpy.random.randint(0, memory_size, size=(batch_size, num_reads))
# init_read_focus_npy = npy_softmax(npy_onehot(init_read_focus_npy, num=memory_size), axis=2)
Example #17
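# Fragment of an actor-critic setup for CartpoleSwingupEnv: declares the data
# shapes for the policy score, backward action, critic label and variance, builds
# the 'ACNet' network with Xavier initialization, and creates an SGD or Adam
# optimizer with a FactorScheduler learning-rate schedule.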
env = CartpoleSwingupEnv()
action_dimension = 1
state_dimension = 4

data_shapes = {
    'data': (batch_size, state_dimension),
    'policy_score': (batch_size, ),
    'policy_backward_action': (batch_size, action_dimension),
    'critic_label': (batch_size, ),
    'var': (batch_size, action_dimension),
}
sym = actor_critic_policy_sym(action_dimension)
net = Base(data_shapes=data_shapes,
           sym_gen=sym,
           name='ACNet',
           initializer=mx.initializer.Xavier(rnd_type='gaussian',
                                             factor_type='avg',
                                             magnitude=1.0),
           ctx=ctx)
lr_scheduler = FactorScheduler(500, 0.1)
if args.optimizer == 'sgd':
    optimizer = mx.optimizer.create(name='sgd',
                                    learning_rate=args.lr,
                                    lr_scheduler=lr_scheduler,
                                    momentum=0.9,
                                    clip_gradient=None,
                                    rescale_grad=1.0,
                                    wd=0.)
elif args.optimizer == 'adam':
    optimizer = mx.optimizer.create(name='adam',
                                    learning_rate=args.lr,
Example #18
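# Visualizes two generated attention maps with OpenCV: a Gaussian map produced in
# the frequency domain (gaussian_map_fft followed by IFFT2D) and a Hann window
# from the custom HannWindowGeneratorOp.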
import numpy
import cv2

attention_size = mx.symbol.Variable('attention_size')
object_size = mx.symbol.Variable('object_size')

hann_map_op = HannWindowGeneratorOp(rows=64, cols=64)

map_fft = gaussian_map_fft(attention_size=attention_size,
                           object_size=object_size,
                           sigma_factor=10,
                           rows=64,
                           cols=64)
map_recons = mx.symbol.IFFT2D(data=map_fft, output_shape=(64, 64))
map_fft = mx.symbol.BlockGrad(map_fft)
data_shapes = {'attention_size': (1, 2), 'object_size': (1, 2)}
net = Base(sym=map_recons, data_shapes=data_shapes)
output = net.forward(data_shapes=data_shapes,
                     attention_size=numpy.array([[0.5, 0.5]]),
                     object_size=numpy.array([[0.3, 0.3]]))[0].asnumpy()
print output.shape
cv2.imshow('image', output[0, 0, :, :])
cv2.waitKey()

hann_map = hann_map_op()
net_hann = Base(sym=hann_map, data_shapes=dict())
output = net_hann.forward(data_shapes=dict())[0].asnumpy()
print output.shape
cv2.imshow('image', output[0, 0, :, :])
cv2.waitKey()
Example #19
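# Evaluation script for a trained DQN: parses command-line options, loads the
# saved Q-network parameters for each requested epoch, and reports the average
# reward (plus the average Q score on held-out samples when visualization is off).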
def main():
    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('arena', 'games', 'roms',
                                             'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Directory path of the model files.')
    parser.add_argument('-m',
                        '--model_prefix',
                        required=True,
                        type=str,
                        default='QNet',
                        help='Prefix of the saved model file.')
    parser.add_argument('-t',
                        '--test-steps',
                        required=False,
                        type=int,
                        default=125000,
                        help='Test steps.')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default='gpu',
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument(
        '-e',
        '--epoch-range',
        required=False,
        type=str,
        default='22',
        help='Epochs to run testing. E.g `-e 0,80`, `-e 0,80,2`')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--symbol',
                        required=False,
                        type=str,
                        default="nature",
                        help='type of network, nature or nips')
    args, unknown = parser.parse_known_args()
    max_start_nullops = 30
    holdout_size = 3200
    replay_memory_size = 1000000
    exploration = 0.05
    history_length = 4
    rows = 84
    cols = 84
    ctx = re.findall('([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0)
           for device, num in ctx]
    q_ctx = mx.Context(*ctx[0])
    minibatch_size = 32
    epoch_range = [int(n) for n in args.epoch_range.split(',')]
    epochs = range(*epoch_range)

    game = AtariGame(rom_path=args.rom,
                     history_length=history_length,
                     resize_mode='scale',
                     resized_rows=rows,
                     replay_start_size=4,
                     resized_cols=cols,
                     max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size,
                     death_end_episode=False,
                     display_screen=args.visualization)

    if not args.visualization:
        holdout_samples = collect_holdout_samples(game,
                                                  sample_num=holdout_size)
    action_num = len(game.action_set)
    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }
    if args.symbol == "nature":
        dqn_sym = dqn_sym_nature(action_num)
    elif args.symbol == "nips":
        dqn_sym = dqn_sym_nips(action_num)
    else:
        raise NotImplementedError
    qnet = Base(data_shapes=data_shapes,
                sym_gen=dqn_sym,
                name=args.model_prefix,
                ctx=q_ctx)

    for epoch in epochs:
        qnet.load_params(name=args.model_prefix,
                         dir_path=args.dir_path,
                         epoch=epoch)
        if not args.visualization:
            avg_q_score = calculate_avg_q(holdout_samples, qnet)
            avg_reward = calculate_avg_reward(game, qnet, args.test_steps,
                                              exploration)
            print("Epoch:%d Avg Reward: %f, Avg Q Score:%f" %
                  (epoch, avg_reward, avg_q_score))
        else:
            avg_reward = calculate_avg_reward(game, qnet, args.test_steps,
                                              exploration)
            print("Epoch:%d Avg Reward: %f" % (epoch, avg_reward))
Example #20
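# Fragment of a recurrent policy-gradient test: unrolls a two-layer LSTM policy
# over time_step steps, declares per-step policy scores in data_shapes, and sets
# up SGD (with gradient rescaling) and a per-step baseline before the training loop.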
minibatch_size = 20
layer_number = 2
net = build_recurrent_sym(time_step, layer_number)
print net.list_arguments()
data_shapes = dict([("data",
                     (minibatch_size, 4))] + [("lstm%d_init_c" % i,
                                               (minibatch_size, 50))
                                              for i in range(layer_number)] +
                   [("lstm%d_init_h" % i, (minibatch_size, 50))
                    for i in range(layer_number)] +
                   [('policy_t%d_score' % i, (minibatch_size, ))
                    for i in range(time_step)])
print data_shapes
qnet = Base(data_shapes=data_shapes,
            sym=net,
            name='PolicyNet',
            initializer=mx.initializer.Xavier(factor_type="in", magnitude=1.0),
            ctx=mx.gpu())

optimizer = mx.optimizer.create(name='sgd',
                                learning_rate=0.00001,
                                clip_gradient=None,
                                rescale_grad=1.0 / minibatch_size,
                                wd=0.00001)
updater = mx.optimizer.get_updater(optimizer)
qnet.print_stat()
baseline = numpy.zeros((time_step, ))
decay_factor = 0.5
for epoch in range(10000):
    data = [("data", numpy.random.rand(minibatch_size, 4))]
    data_ndarray = {k: nd.array(v, ctx=mx.gpu()) for k, v in data}
Example #21
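# Asynchronous multi-actor DQN training: each actor plays its own Atari game with
# an epsilon-greedy policy drawn from three annealed epsilon schedules, minibatches
# are sampled from per-actor replay memories, the Q-network is updated either
# locally or through a kvstore (optionally with EASGD), and the target network is
# refreshed every freeze_interval updates.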
def main():
    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('arena', 'games', 'roms',
                                             'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--rms-decay',
                        required=False,
                        type=float,
                        default=0.95,
                        help='Decay rate of the RMSProp')
    parser.add_argument('--clip-gradient',
                        required=False,
                        type=float,
                        default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q',
                        required=False,
                        type=bool,
                        default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd',
                        required=False,
                        type=float,
                        default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default=None,
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Saving directory of model files.')
    parser.add_argument(
        '--start-eps',
        required=False,
        type=float,
        default=1.0,
        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size',
                        required=False,
                        type=int,
                        default=50000,
                        help='The step that the training starts')
    parser.add_argument(
        '--kvstore-update-period',
        required=False,
        type=int,
        default=16,
        help='The period that the worker updates the parameters from the server'
    )
    parser.add_argument(
        '--kv-type',
        required=False,
        type=str,
        default=None,
        help=
        'type of kvstore, default will not use kvstore, could also be dist_async'
    )
    parser.add_argument('--optimizer',
                        required=False,
                        type=str,
                        default="adagrad",
                        help='type of optimizer')
    parser.add_argument('--nactor',
                        required=False,
                        type=int,
                        default=16,
                        help='number of actor')
    parser.add_argument('--exploration-period',
                        required=False,
                        type=int,
                        default=4000000,
                        help='length of annealing of epsilon greedy policy')
    parser.add_argument('--replay-memory-size',
                        required=False,
                        type=int,
                        default=100,
                        help='size of replay memory')
    parser.add_argument('--single-batch-size',
                        required=False,
                        type=int,
                        default=5,
                        help='batch size for every actor')
    parser.add_argument('--symbol',
                        required=False,
                        type=str,
                        default="nature",
                        help='type of network, nature or nips')
    parser.add_argument('--sample-policy',
                        required=False,
                        type=str,
                        default="recent",
                        help='minibatch sampling policy, recent or random')
    parser.add_argument('--epoch-num',
                        required=False,
                        type=int,
                        default=50,
                        help='number of epochs')
    parser.add_argument('--param-update-period',
                        required=False,
                        type=int,
                        default=5,
                        help='Parameter update period')
    parser.add_argument('--resize-mode',
                        required=False,
                        type=str,
                        default="scale",
                        help='Resize mode, scale or crop')
    parser.add_argument('--eps-update-period',
                        required=False,
                        type=int,
                        default=8000,
                        help='eps greedy policy update period')
    parser.add_argument('--server-optimizer',
                        required=False,
                        type=str,
                        default="easgd",
                        help='type of server optimizer')
    parser.add_argument('--nworker',
                        required=False,
                        type=int,
                        default=1,
                        help='number of kv worker')
    parser.add_argument('--easgd-alpha',
                        required=False,
                        type=float,
                        default=0.01,
                        help='easgd alpha')
    args, unknown = parser.parse_known_args()
    logging.info(str(args))

    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        time_str = time.strftime("%m%d_%H%M_%S", time.localtime())
        args.dir_path = ('dqn-%s-%d_' % (rom_name,int(args.lr*10**5)))+time_str \
                        + "_" + os.environ.get('DMLC_TASK_ID')
        logging.info("saving to dir: " + args.dir_path)
    if args.ctx == None:
        args.ctx = os.environ.get('CTX')
    logging.info("Context: %s" % args.ctx)
    ctx = re.findall('([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0)
           for device, num in ctx]

    # Async version
    nactor = args.nactor
    param_update_period = args.param_update_period

    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = args.replay_memory_size
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])
    games = []
    for g in range(nactor):
        games.append(
            AtariGame(rom_path=args.rom,
                      resize_mode=args.resize_mode,
                      replay_start_size=replay_start_size,
                      resized_rows=rows,
                      resized_cols=cols,
                      max_null_op=max_start_nullops,
                      replay_memory_size=replay_memory_size,
                      display_screen=args.visualization,
                      history_length=history_length))

    ##RUN NATURE
    freeze_interval = 40000 / nactor
    freeze_interval /= param_update_period
    epoch_num = args.epoch_num
    steps_per_epoch = 4000000 / nactor
    discount = 0.99
    save_screens = False
    eps_start = numpy.ones((3, )) * args.start_eps
    eps_min = numpy.array([0.1, 0.01, 0.5])
    eps_decay = (eps_start - eps_min) / (args.exploration_period / nactor)
    eps_curr = eps_start
    eps_id = numpy.zeros((nactor, ))
    eps_update_period = args.eps_update_period
    eps_update_count = numpy.zeros((nactor, ))

    single_batch_size = args.single_batch_size
    minibatch_size = nactor * single_batch_size
    action_num = len(games[0].action_set)
    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }

    if args.symbol == "nature":
        dqn_sym = dqn_sym_nature(action_num)
    elif args.symbol == "nips":
        dqn_sym = dqn_sym_nips(action_num)
    else:
        raise NotImplementedError
    qnet = Base(data_shapes=data_shapes,
                sym=dqn_sym,
                name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    if args.optimizer == "adagrad":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        rescale_grad=1.0,
                                        wd=args.wd)
    elif args.optimizer == "rmsprop" or args.optimizer == "rmspropnoncentered":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        gamma1=args.rms_decay,
                                        gamma2=0,
                                        rescale_grad=1.0,
                                        wd=args.wd)
        lr_decay = (args.lr - 0) / (steps_per_epoch * epoch_num /
                                    param_update_period)

    # Create kvstore
    use_easgd = False
    if args.kv_type != None:
        kvType = args.kv_type
        kv = kvstore.create(kvType)
        #Initialize kvstore
        for idx, v in enumerate(qnet.params.values()):
            kv.init(idx, v)
        if args.server_optimizer == "easgd":
            use_easgd = True
            easgd_beta = 0.9
            easgd_alpha = args.easgd_alpha
            server_optimizer = mx.optimizer.create(name="ServerEasgd",
                                                   learning_rate=easgd_alpha)
            easgd_eta = 0.00025
            central_weight = OrderedDict([(n, v.copyto(q_ctx))
                                          for n, v in qnet.params.items()])
            kv.set_optimizer(server_optimizer)
            updater = mx.optimizer.get_updater(optimizer)
        else:
            kv.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
        npy_rng = numpy.random.RandomState(123456 + kv.rank)
    else:
        updater = mx.optimizer.get_updater(optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    states_buffer_for_act = numpy.zeros(
        (nactor, history_length) + (rows, cols), dtype='uint8')
    states_buffer_for_train = numpy.zeros(
        (minibatch_size, history_length + 1) + (rows, cols), dtype='uint8')
    next_states_buffer_for_train = numpy.zeros(
        (minibatch_size, history_length) + (rows, cols), dtype='uint8')
    actions_buffer_for_train = numpy.zeros((minibatch_size, ), dtype='uint8')
    rewards_buffer_for_train = numpy.zeros((minibatch_size, ), dtype='float32')
    terminate_flags_buffer_for_train = numpy.zeros((minibatch_size, ),
                                                   dtype='bool')
    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    ave_fps = 0
    ave_loss = 0
    time_for_info = time.time()
    parallel_executor = concurrent.futures.ThreadPoolExecutor(nactor)
    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        #
        for g, game in enumerate(games):
            game.start()
            game.begin_episode()
            eps_rand = npy_rng.rand()
            if eps_rand < 0.4:
                eps_id[g] = 0
            elif eps_rand < 0.7:
                eps_id[g] = 1
            else:
                eps_id[g] = 2
        episode_stats = [EpisodeStat() for i in range(len(games))]
        while steps_left > 0:
            for g, game in enumerate(games):
                if game.episode_terminate:
                    episode += 1
                    epoch_reward += game.episode_reward
                    if args.kv_type != None:
                        info_str = "Node[%d]: " % kv.rank
                    else:
                        info_str = ""
                    info_str += "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                                % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                                   ave_fps, (eps_curr[eps_id[g]]))
                    info_str += ", Avg Loss:%f" % ave_loss
                    if episode_stats[g].episode_action_step > 0:
                        info_str += ", Avg Q Value:%f/%d" % (
                            episode_stats[g].episode_q_value /
                            episode_stats[g].episode_action_step,
                            episode_stats[g].episode_action_step)
                    if g == 0: logging.info(info_str)
                    if eps_update_count[g] * eps_update_period < total_steps:
                        eps_rand = npy_rng.rand()
                        if eps_rand < 0.4:
                            eps_id[g] = 0
                        elif eps_rand < 0.7:
                            eps_id[g] = 1
                        else:
                            eps_id[g] = 2
                        eps_update_count[g] += 1
                    game.begin_episode(steps_left)
                    episode_stats[g] = EpisodeStat()

            if total_steps > history_length:
                for g, game in enumerate(games):
                    current_state = game.current_state()
                    states_buffer_for_act[g] = current_state

            states = nd.array(states_buffer_for_act, ctx=q_ctx) / float(255.0)

            qval_npy = qnet.forward(batch_size=nactor,
                                    data=states)[0].asnumpy()
            actions_that_max_q = numpy.argmax(qval_npy, axis=1)
            actions = [0] * nactor
            for g, game in enumerate(games):
                # 1. We need to choose a new action based on the current game status
                if games[g].state_enabled and games[
                        g].replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr[eps_id[g]])
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        action = actions_that_max_q[g]
                        episode_stats[g].episode_q_value += qval_npy[g, action]
                        episode_stats[g].episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)
                actions[g] = action
            # t0=time.time()
            for ret in parallel_executor.map(play_game, zip(games, actions)):
                pass
            # t1=time.time()
            # logging.info("play time: %f" % (t1-t0))
            eps_curr = numpy.maximum(eps_curr - eps_decay, eps_min)
            total_steps += 1
            steps_left -= 1
            if total_steps % 100 == 0:
                this_time = time.time()
                ave_fps = (100 / (this_time - time_for_info))
                time_for_info = this_time

            # 3. Update our Q network if we can start sampling from the replay memory
            #    Also, we update every `update_interval`
            if total_steps > minibatch_size and \
                total_steps % (param_update_period) == 0 and \
                games[-1].replay_memory.sample_enabled:
                if use_easgd and training_steps % kvstore_update_period == 0:
                    for paramIndex in range(len(qnet.params)):
                        k = qnet.params.keys()[paramIndex]
                        kv.pull(paramIndex,
                                central_weight[k],
                                priority=-paramIndex)
                        qnet.params[k][:] -= easgd_alpha * (qnet.params[k] -
                                                            central_weight[k])
                        kv.push(paramIndex,
                                qnet.params[k],
                                priority=-paramIndex)
                # 3.1 Draw sample from the replay_memory
                for g, game in enumerate(games):
                    episode_stats[g].episode_update_step += 1
                    nsample = single_batch_size
                    i0 = (g * nsample)
                    i1 = (g + 1) * nsample
                    if args.sample_policy == "recent":
                        action, reward, terminate_flag=game.replay_memory.sample_last(batch_size=nsample,\
                            states=states_buffer_for_train,offset=i0)
                    elif args.sample_policy == "random":
                        action, reward, terminate_flag=game.replay_memory.sample_inplace(batch_size=nsample,\
                            states=states_buffer_for_train,offset=i0)
                    actions_buffer_for_train[i0:i1] = action
                    rewards_buffer_for_train[i0:i1] = reward
                    terminate_flags_buffer_for_train[i0:i1] = terminate_flag
                states = nd.array(states_buffer_for_train[:, :-1],
                                  ctx=q_ctx) / float(255.0)
                next_states = nd.array(states_buffer_for_train[:, 1:],
                                       ctx=q_ctx) / float(255.0)
                actions = nd.array(actions_buffer_for_train, ctx=q_ctx)
                rewards = nd.array(rewards_buffer_for_train, ctx=q_ctx)
                terminate_flags = nd.array(terminate_flags_buffer_for_train,
                                           ctx=q_ctx)

                # 3.2 Use the target network to compute the scores and
                #     get the corresponding target rewards
                if not args.double_q:
                    target_qval = target_qnet.forward(
                        batch_size=minibatch_size, data=next_states)[0]
                    target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                            nd.argmax_channel(target_qval))\
                                       * (1.0 - terminate_flags) * discount
                else:
                    target_qval = target_qnet.forward(
                        batch_size=minibatch_size, data=next_states)[0]
                    qval = qnet.forward(batch_size=minibatch_size,
                                        data=next_states)[0]

                    target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                            nd.argmax_channel(qval))\
                                       * (1.0 - terminate_flags) * discount

                outputs = qnet.forward(batch_size=minibatch_size,
                                       is_train=True,
                                       data=states,
                                       dqn_action=actions,
                                       dqn_reward=target_rewards)
                qnet.backward(batch_size=minibatch_size)

                if args.kv_type is None or use_easgd:
                    qnet.update(updater=updater)
                else:
                    update_on_kvstore(kv, qnet.params, qnet.params_grad)

                # 3.3 Calculate Loss
                diff = nd.abs(
                    nd.choose_element_0index(outputs[0], actions) -
                    target_rewards)
                quadratic_part = nd.clip(diff, -1, 1)
                loss = (0.5 * nd.sum(nd.square(quadratic_part)) +
                        nd.sum(diff - quadratic_part)).asscalar()
                if ave_loss == 0:
                    ave_loss = loss
                else:
                    ave_loss = 0.95 * ave_loss + 0.05 * loss

                # 3.3 Update the target network every freeze_interval
                # (We can do annealing instead of hard copy)
                if training_steps % freeze_interval == 0:
                    qnet.copy_params_to(target_qnet)

                if args.optimizer == "rmsprop" or args.optimizer == "rmspropnoncentered":
                    optimizer.lr -= lr_decay

                if save_screens and training_steps % (
                        60 * 60 * 2 / param_update_period) == 0:
                    logging.info("saving screenshots")
                    for g in range(nactor):
                        screen = states_buffer_for_train[(
                            g * single_batch_size), -2, :, :].reshape(
                                states_buffer_for_train.shape[2:])
                        cv2.imwrite("screen_" + str(g) + ".png", screen)
                training_steps += 1

        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        if args.kv_type != None:
            logging.info(
                "Node[%d]: Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                (kv.rank, epoch, fps, epoch_reward / float(episode), episode))
        else:
            logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                         (epoch, fps, epoch_reward / float(episode), episode))
Example #22
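# Checks the custom complex binary ops (complex_mul_*/complex_div_*) in their
# complex-complex ('cc'), real-complex ('rc') and complex-real ('cr') variants:
# forward outputs and input gradients are compared against NumPy reference
# functions, printing squared differences (real/imaginary parts interleaved on
# the last axis).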
def test_mxnet_binary(test_operation, typ):
    if 'div' == test_operation:
        numpy_outf = complex_div
        numpy_gradf = complex_div_grad
        if typ == 'rc':
            test_sym = mx.symbol.complex_div_rc
        elif typ == 'cc':
            test_sym = mx.symbol.complex_div_cc
        elif typ == 'cr':
            test_sym = mx.symbol.complex_div_cr
    else:
        numpy_outf = complex_mul
        numpy_gradf = complex_mul_grad
        if typ == 'rc':
            test_sym = mx.symbol.complex_mul_rc
        elif typ == 'cc':
            test_sym = mx.symbol.complex_mul_cc
        elif typ == 'cr':
            test_sym = mx.symbol.complex_mul_cr
    a = mx.symbol.Variable('a')
    b = mx.symbol.Variable('b')
    c = test_sym(a, b)
    base_complex_shape = (10, 10, 6)
    base_real_shape = (10, 10, 3)
    if 'cc' == typ:
        data_shapes = {'a': base_complex_shape, 'b': base_complex_shape}
        a_complex_npy = numpy.random.rand(*base_real_shape) + \
                        numpy.random.rand(*base_real_shape) * 1j
        b_complex_npy = numpy.random.rand(*base_real_shape) + \
                        numpy.random.rand(*base_real_shape) * 1j
        a_npy = numpy.empty(data_shapes['a'])
        b_npy = numpy.empty(data_shapes['b'])
        a_npy[:, :, ::2] = a_complex_npy.real
        a_npy[:, :, 1::2] = a_complex_npy.imag
        b_npy[:, :, ::2] = b_complex_npy.real
        b_npy[:, :, 1::2] = b_complex_npy.imag
        net = Base(data_shapes=data_shapes, sym=c)
        outputs = net.forward(a=a_npy, b=b_npy)
        out_grad = numpy.random.rand(*data_shapes['a'])
        print numpy.square(
            outputs[0].asnumpy()[:, :, ::2] -
            numpy_outf(a_complex_npy, b_complex_npy).real).sum()
        print numpy.square(
            outputs[0].asnumpy()[:, :, 1::2] -
            numpy_outf(a_complex_npy, b_complex_npy).imag).sum()
        net.backward(out_grads=[nd.array(out_grad, ctx=mx.gpu())])
        a_grad_npy, b_grad_npy = numpy_gradf(
            out_grad[:, :, ::2] + out_grad[:, :, 1::2] * 1j, a_complex_npy,
            b_complex_npy)
        print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, :, ::2] -
                           a_grad_npy.real).sum()
        print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, :, 1::2] -
                           a_grad_npy.imag).sum()
        print numpy.square(net.exe.grad_dict['b'].asnumpy()[:, :, ::2] -
                           b_grad_npy.real).sum()
        print numpy.square(net.exe.grad_dict['b'].asnumpy()[:, :, 1::2] -
                           b_grad_npy.imag).sum()
    elif 'rc' == typ:
        data_shapes = {'a': base_real_shape, 'b': base_complex_shape}
        a_complex_npy = numpy.random.rand(*base_real_shape)
        b_complex_npy = numpy.random.rand(*base_real_shape) + \
                        numpy.random.rand(*base_real_shape) * 1j
        a_npy = numpy.empty(data_shapes['a'])
        b_npy = numpy.empty(data_shapes['b'])
        a_npy = a_complex_npy
        b_npy[:, :, ::2] = b_complex_npy.real
        b_npy[:, :, 1::2] = b_complex_npy.imag
        net = Base(data_shapes=data_shapes, sym=c)
        outputs = net.forward(a=a_npy, b=b_npy)
        out_grad = numpy.random.rand(*data_shapes['b'])
        print numpy.square(
            outputs[0].asnumpy()[:, :, ::2] -
            numpy_outf(a_complex_npy, b_complex_npy).real).sum()
        print numpy.square(
            outputs[0].asnumpy()[:, :, 1::2] -
            numpy_outf(a_complex_npy, b_complex_npy).imag).sum()
        net.backward(out_grads=[nd.array(out_grad, ctx=mx.gpu())])
        a_grad_npy, b_grad_npy = numpy_gradf(
            out_grad[:, :, ::2] + out_grad[:, :, 1::2] * 1j, a_complex_npy,
            b_complex_npy)
        print numpy.square(net.exe.grad_dict['a'].asnumpy() -
                           a_grad_npy.real).sum()
        print numpy.square(net.exe.grad_dict['b'].asnumpy()[:, :, ::2] -
                           b_grad_npy.real).sum()
        print numpy.square(net.exe.grad_dict['b'].asnumpy()[:, :, 1::2] -
                           b_grad_npy.imag).sum()
    else:
        data_shapes = {'a': base_complex_shape, 'b': base_real_shape}
        a_complex_npy = numpy.random.rand(*base_real_shape) + \
                        numpy.random.rand(*base_real_shape) * 1j
        b_complex_npy = numpy.random.rand(*data_shapes['b'])
        a_npy = numpy.empty(data_shapes['a'])
        b_npy = numpy.empty(data_shapes['b'])
        a_npy[:, :, ::2] = a_complex_npy.real
        a_npy[:, :, 1::2] = a_complex_npy.imag
        b_npy = b_complex_npy.real
        net = Base(data_shapes=data_shapes, sym=c)
        outputs = net.forward(a=a_npy, b=b_npy)
        out_grad = numpy.random.rand(*data_shapes['a'])
        print numpy.square(
            outputs[0].asnumpy()[:, :, ::2] -
            numpy_outf(a_complex_npy, b_complex_npy).real).sum()
        print numpy.square(
            outputs[0].asnumpy()[:, :, 1::2] -
            numpy_outf(a_complex_npy, b_complex_npy).imag).sum()
        net.backward(out_grads=[nd.array(out_grad, ctx=mx.gpu())])
        a_grad_npy, b_grad_npy = numpy_gradf(
            out_grad[:, :, ::2] + out_grad[:, :, 1::2] * 1j, a_complex_npy,
            b_complex_npy)
        print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, :, ::2] -
                           a_grad_npy.real).sum()
        print numpy.square(net.exe.grad_dict['a'].asnumpy()[:, :, 1::2] -
                           a_grad_npy.imag).sum()
        print numpy.square(net.exe.grad_dict['b'].asnumpy() -
                           b_grad_npy.real).sum()
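
All three branches above use the same packed layout: a complex tensor of shape (N, C, W) is stored as a real tensor of shape (N, C, 2W), with the real parts at the even indices of the last axis and the imaginary parts at the odd indices. A minimal NumPy sketch of that packing and its inverse (the helper names pack_complex/unpack_complex and the shapes are ours, introduced only for illustration):

import numpy

def pack_complex(z):
    # Interleave real/imag parts along the last axis: (..., W) complex -> (..., 2W) real.
    out = numpy.empty(z.shape[:-1] + (2 * z.shape[-1],))
    out[..., ::2] = z.real
    out[..., 1::2] = z.imag
    return out

def unpack_complex(x):
    # Inverse of pack_complex: (..., 2W) real -> (..., W) complex.
    return x[..., ::2] + 1j * x[..., 1::2]

z = numpy.random.rand(2, 3, 4) + 1j * numpy.random.rand(2, 3, 4)
assert numpy.allclose(unpack_complex(pack_complex(z)), z)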
Example #23
File: test_policy.py Project: flyers/Arena
def test_lognormal():
    var = mx.symbol.Variable('var')
    data = mx.symbol.Variable('data')
    net_mean = mx.symbol.FullyConnected(data=data,
                                        name='fc_mean_1',
                                        num_hidden=20)
    net_mean = mx.symbol.Activation(data=net_mean,
                                    name='fc_mean_relu_1',
                                    act_type='relu')
    net_mean = mx.symbol.FullyConnected(data=net_mean,
                                        name='fc_mean_2',
                                        num_hidden=20)
    net_mean = mx.symbol.Activation(data=net_mean,
                                    name='fc_mean_relu_2',
                                    act_type='relu')
    net_mean = mx.symbol.FullyConnected(data=net_mean,
                                        name='fc_mean_3',
                                        num_hidden=10)
    net_var = mx.symbol.FullyConnected(data=data,
                                       name='fc_var_1',
                                       num_hidden=10)
    net_var = mx.symbol.Activation(data=net_var,
                                   name='fc_var_softplus_1',
                                   act_type='softrelu')
    net = mx.symbol.Custom(mean=net_mean,
                           var=net_var,
                           name='policy',
                           deterministic=False,
                           entropy_regularization=0.01,
                           op_type='LogNormalPolicy')
    ctx = mx.gpu()
    minibatch_size = 100
    data_shapes = {
        'data': (minibatch_size, 10),
        'policy_score': (minibatch_size, )
    }  #, 'var':(minibatch_size,)}
    qnet = Base(data_shapes=data_shapes,
                sym_gen=net,
                name='PolicyNet',
                initializer=mx.initializer.Xavier(factor_type="in",
                                                  magnitude=1.0),
                ctx=ctx)
    print qnet.internal_sym_names

    lr = 0.01
    lr_scheduler = FactorScheduler(1000, 1.0 / 1.5)
    optimizer = mx.optimizer.create(
        name='sgd',
        learning_rate=lr,  #momentum=0.9,
        clip_gradient=None,
        lr_scheduler=lr_scheduler,
        rescale_grad=1.0,
        wd=0.)
    updater = mx.optimizer.get_updater(optimizer)
    total_iter = 1000000
    stats = numpy.zeros((total_iter, 3), dtype=numpy.float32)
    plt.ion()
    fig, ax = plt.subplots()
    lines, = ax.plot([], [])
    ax.set_autoscaley_on(True)
    baseline = 0
    for i in range(total_iter):
        #    for k, v in qnet.params.items():
        #        print k, v.asnumpy()
        data = numpy.random.randn(minibatch_size, 10)
        means = qnet.compute_internal(sym_name="fc_mean_3_output",
                                      data=data).asnumpy()
        vars = qnet.compute_internal(sym_name="fc_var_softplus_1_output",
                                     data=data).asnumpy()

        outputs = qnet.forward(
            is_train=True,
            data=data)  #, var=0.5*numpy.ones((minibatch_size, )))
        action = outputs[0].asnumpy()
        score = simple_game_multimodal(data, action, 1)
        baseline = baseline - 0.01 * (baseline - score.mean())
        print 'score=', score.mean(), 'err=', numpy.square(
            means -
            data * data).mean(), 'var=', vars.mean(), 'baseline=', baseline
        stats[i] = [
            score.mean(),
            numpy.square(means - data * data).mean(),
            vars.mean()
        ]
        qnet.backward(policy_score=score - baseline)
        norm_clipping(qnet.params_grad, 10)
        qnet.update(updater)
        if i % 10 == 0:
            update_line(lines, fig, ax, i,
                        score.mean())  #numpy.square(means - data*data).mean())
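
The loop above scales the policy gradient by (score - baseline), where the baseline is an exponential moving average of the mean score, i.e. a simple variance-reduction trick for the REINFORCE-style update. A minimal sketch of just that bookkeeping (the score array here is random and stands in for simple_game_multimodal; the policy network itself is omitted):

import numpy

baseline = 0.0
for step in range(5):
    score = numpy.random.rand(100)                           # stand-in for the game score
    baseline = baseline - 0.01 * (baseline - score.mean())   # EMA of the mean score
    advantage = score - baseline                             # passed as policy_score to backward()
    print step, score.mean(), baseline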
Example #24
def main():
    parser = argparse.ArgumentParser(
        description='Script to test the trained network on a game.')
    parser.add_argument('-r',
                        '--rom',
                        required=False,
                        type=str,
                        default=os.path.join('arena', 'games', 'roms',
                                             'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v',
                        '--visualization',
                        required=False,
                        type=int,
                        default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps',
                        required=False,
                        type=float,
                        default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient',
                        required=False,
                        type=float,
                        default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q',
                        required=False,
                        type=bool,
                        default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd',
                        required=False,
                        type=float,
                        default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument(
        '-c',
        '--ctx',
        required=False,
        type=str,
        default='gpu',
        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d',
                        '--dir-path',
                        required=False,
                        type=str,
                        default='',
                        help='Saving directory of model files.')
    parser.add_argument(
        '--start-eps',
        required=False,
        type=float,
        default=1.0,
        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size',
                        required=False,
                        type=int,
                        default=50000,
                        help='The step at which training starts')
    parser.add_argument(
        '--kvstore-update-period',
        required=False,
        type=int,
        default=1,
        help='The period at which the worker updates the parameters from the server'
    )
    parser.add_argument(
        '--kv-type',
        required=False,
        type=str,
        default=None,
        help=
        'Type of kvstore. By default no kvstore is used; dist_async is also supported'
    )
    parser.add_argument('--optimizer',
                        required=False,
                        type=str,
                        default="adagrad",
                        help='type of optimizer')
    args = parser.parse_args()

    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s-lr%g' % (rom_name, args.lr)
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84

    ctx = parse_ctx(args.ctx)
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom,
                     resize_mode='scale',
                     replay_start_size=replay_start_size,
                     resized_rows=rows,
                     resized_cols=cols,
                     max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size,
                     display_screen=args.visualization,
                     history_length=history_length)

    ## Hyperparameters following the Nature DQN setup
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - eps_min) / 1000000
    eps_curr = eps_start
    freeze_interval /= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {
        'data': (minibatch_size, history_length) + (rows, cols),
        'dqn_action': (minibatch_size, ),
        'dqn_reward': (minibatch_size, )
    }
    dqn_sym = dqn_sym_nature(action_num)
    qnet = Base(data_shapes=data_shapes,
                sym_gen=dqn_sym,
                name='QNet',
                initializer=DQNInitializer(factor_type="in"),
                ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)

    use_easgd = False
    if args.optimizer != "easgd":
        optimizer = mx.optimizer.create(name=args.optimizer,
                                        learning_rate=args.lr,
                                        eps=args.eps,
                                        clip_gradient=args.clip_gradient,
                                        rescale_grad=1.0,
                                        wd=args.wd)
    else:
        use_easgd = True
        easgd_beta = 0.9
        easgd_p = 4
        easgd_alpha = easgd_beta / (args.kvstore_update_period * easgd_p)
        server_optimizer = mx.optimizer.create(name="ServerEASGD",
                                               learning_rate=easgd_alpha)
        easgd_eta = 0.00025
        local_optimizer = mx.optimizer.create(name='adagrad',
                                              learning_rate=args.lr,
                                              eps=args.eps,
                                              clip_gradient=args.clip_gradient,
                                              rescale_grad=1.0,
                                              wd=args.wd)
        central_weight = OrderedDict([(n, nd.zeros(v.shape, ctx=q_ctx))
                                      for n, v in qnet.params.items()])
    # Create KVStore
    if args.kv_type is not None:
        kv = kvstore.create(args.kv_type)

        # Initialize KVStore
        for idx, v in enumerate(qnet.params.values()):
            kv.init(idx, v)

        # Set Server optimizer on KVStore
        if not use_easgd:
            kv.set_optimizer(optimizer)
        else:
            kv.set_optimizer(server_optimizer)
            local_updater = mx.optimizer.get_updater(local_optimizer)
        kvstore_update_period = args.kvstore_update_period
        args.dir_path = args.dir_path + "-" + str(kv.rank)
    else:
        updater = mx.optimizer.get_updater(optimizer)

    qnet.print_stat()
    target_qnet.print_stat()

    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. We need to choose a new action based on the current game status
                if game.state_enabled and game.replay_memory.sample_enabled:
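                    # Epsilon-greedy: explore with probability eps_curr, which is
                    # annealed linearly from eps_start towards eps_min (eps_decay per step).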
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        current_state = game.current_state()
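                        # Scale the 0-255 pixel values to [0, 1] before feeding the network.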
                        state = nd.array(
                            current_state.reshape((1, ) + current_state.shape),
                            ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(is_train=False,
                                                data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update our Q network if we can start sampling from the replay memory
                #    Also, we update every `update_interval`
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states,
                                           ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Use the target network to compute the scores and
                    #     get the corresponding target rewards
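                    #     Standard DQN target:  r + discount * max_a' Q_target(s', a') * (1 - done)
                    #     Double DQN target:    r + discount * Q_target(s', argmax_a' Q_online(s', a')) * (1 - done)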
                    if not args.double_q:
                        target_qval = target_qnet.forward(is_train=False,
                                                          data=next_states)[0]
                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(target_qval))\
                                           * (1.0 - terminate_flags) * discount
                    else:
                        target_qval = target_qnet.forward(is_train=False,
                                                          data=next_states)[0]
                        qval = qnet.forward(is_train=False,
                                            data=next_states)[0]

                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(qval))\
                                           * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(is_train=True,
                                           data=states,
                                           dqn_action=actions,
                                           dqn_reward=target_rewards)
                    qnet.backward()

                    if args.kv_type is not None:
                        if use_easgd:
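                            # Elastic averaging (EASGD): pull the central weights,
                            # move the local weights toward them by easgd_alpha,
                            # and push the local weights back to the server.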
                            if total_steps % kvstore_update_period == 0:
                                for ind, k in enumerate(qnet.params.keys()):
                                    kv.pull(ind,
                                            central_weight[k],
                                            priority=-ind)
                                    qnet.params[k][:] -= easgd_alpha * \
                                                         (qnet.params[k] - central_weight[k])
                                    kv.push(ind, qnet.params[k], priority=-ind)
                            qnet.update(updater=local_updater)
                        else:
                            update_on_kvstore(kv, qnet.params,
                                              qnet.params_grad)
                    else:
                        qnet.update(updater=updater)

                    # 3.3 Calculate Loss
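                    #     This is the Huber (smooth-L1) loss with threshold 1:
                    #     0.5 * d^2 for |d| <= 1 and |d| - 0.5 otherwise, written as
                    #     0.5 * clip(d)^2 + (d - clip(d)) with d = |Q(s, a) - target|,
                    #     summed over the minibatch.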
                    diff = nd.abs(
                        nd.choose_element_0index(outputs[0], actions) -
                        target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = 0.5 * nd.sum(nd.square(quadratic_part)).asnumpy()[0] +\
                           nd.sum(diff - quadratic_part).asnumpy()[0]
                    episode_loss += loss

                    # 3.3 Update the target network every freeze_interval
                    # (We can do annealing instead of hard copy)
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            if args.kv_type is not None:
                info_str = "Node[%d]: " % kv.rank
            else:
                info_str = ""
            info_str += "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (
                    episode_loss / episode_update_step, episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (
                    episode_q_value / episode_action_step, episode_action_step)
            logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        if args.kv_type is not None:
            logging.info(
                "Node[%d]: Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                (kv.rank, epoch, fps, epoch_reward / float(episode), episode))
        else:
            logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" %
                         (epoch, fps, epoch_reward / float(episode), episode))
Example #25
File: kvtest.py Project: flyers/Arena
def main():
    parser = argparse.ArgumentParser(description='Script to test the trained network on a game.')
    parser.add_argument('-r', '--rom', required=False, type=str,
                        default=os.path.join('arena', 'games', 'roms', 'breakout.bin'),
                        help='Path of the ROM File.')
    parser.add_argument('-v', '--visualization', required=False, type=int, default=0,
                        help='Visualize the runs.')
    parser.add_argument('--lr', required=False, type=float, default=0.01,
                        help='Learning rate of the AdaGrad optimizer')
    parser.add_argument('--eps', required=False, type=float, default=0.01,
                        help='Eps of the AdaGrad optimizer')
    parser.add_argument('--clip-gradient', required=False, type=float, default=None,
                        help='Clip threshold of the AdaGrad optimizer')
    parser.add_argument('--double-q', required=False, type=bool, default=False,
                        help='Use Double DQN')
    parser.add_argument('--wd', required=False, type=float, default=0.0,
                        help='Weight of the L2 Regularizer')
    parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu',
                        help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
    parser.add_argument('-d', '--dir-path', required=False, type=str, default='',
                        help='Saving directory of model files.')
    parser.add_argument('--start-eps', required=False, type=float, default=1.0,
                        help='Eps of the epsilon-greedy policy at the beginning')
    parser.add_argument('--replay-start-size', required=False, type=int, default=50000,
                        help='The step at which training starts')
    parser.add_argument('--kvstore-update-period', required=False, type=int, default=1,
                        help='The period at which the worker updates the parameters from the server')
    parser.add_argument('--kv-type', required=False, type=str, default=None,
                        help='Type of kvstore. By default no kvstore is used; dist_async is also supported')
    args, unknown = parser.parse_known_args()
    if args.dir_path == '':
        rom_name = os.path.splitext(os.path.basename(args.rom))[0]
        args.dir_path = 'dqn-%s' % rom_name
    ctx = re.findall(r'([a-z]+)(\d*)', args.ctx)
    ctx = [(device, int(num)) if len(num) > 0 else (device, 0) for device, num in ctx]
    replay_start_size = args.replay_start_size
    max_start_nullops = 30
    replay_memory_size = 1000000
    history_length = 4
    rows = 84
    cols = 84
    q_ctx = mx.Context(*ctx[0])

    game = AtariGame(rom_path=args.rom, resize_mode='scale', replay_start_size=replay_start_size,
                     resized_rows=rows, resized_cols=cols, max_null_op=max_start_nullops,
                     replay_memory_size=replay_memory_size, display_screen=args.visualization,
                     history_length=history_length)

    ## Hyperparameters following the Nature DQN setup
    freeze_interval = 10000
    epoch_num = 200
    steps_per_epoch = 250000
    update_interval = 4
    discount = 0.99

    eps_start = args.start_eps
    eps_min = 0.1
    eps_decay = (eps_start - eps_min) / 1000000
    eps_curr = eps_start
    freeze_interval /= update_interval
    minibatch_size = 32
    action_num = len(game.action_set)

    data_shapes = {'data': (minibatch_size, history_length) + (rows, cols),
                   'dqn_action': (minibatch_size,), 'dqn_reward': (minibatch_size,)}
    #optimizer = mx.optimizer.create(name='sgd', learning_rate=args.lr,wd=args.wd)
    optimizer = mx.optimizer.Nop()
    dqn_output_op = DQNOutputNpyOp()
    dqn_sym = dqn_sym_nature(action_num, dqn_output_op)
    qnet = Base(data_shapes=data_shapes, sym=dqn_sym, name='QNet',
                  initializer=DQNInitializer(factor_type="in"),
                  ctx=q_ctx)
    target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx)
    # Dummy tensors for the raw kvstore push/pull bandwidth test
    # (only referenced by the commented-out blocks below)
    testShape = (1, 1686180 * 100)
    testParam = nd.ones(testShape, ctx=q_ctx)
    testGrad = nd.zeros(testShape, ctx=q_ctx)

    # Create kvstore

    if args.kv_type is not None:
        kvType = args.kv_type
        kvStore = kvstore.create(kvType)
        # Initialize kvstore
        for idx, v in enumerate(qnet.params.values()):
            kvStore.init(idx, v)
        # Set optimizer on kvstore
        kvStore.set_optimizer(optimizer)
        kvstore_update_period = args.kvstore_update_period
    else:
        updater = mx.optimizer.get_updater(optimizer)

    # if args.kv_type != None:
    #     kvType = args.kv_type
    #     kvStore = kvstore.create(kvType)
    #     kvStore.init(0,testParam)
    #     testOptimizer = mx.optimizer.Nop()
    #     kvStore.set_optimizer(testOptimizer)
    #     kvstore_update_period = args.kvstore_update_period


    qnet.print_stat()
    target_qnet.print_stat()
    # Begin Playing Game
    training_steps = 0
    total_steps = 0
    while True:
        time_before_wait = time.time()

        # kvStore.push(0,testGrad,priority=0)
        # kvStore.pull(0,testParam,priority=0)
        # testParam.wait_to_read()

        for paramIndex in range(len(qnet.params)):  # range(6)
            k = qnet.params.keys()[paramIndex]
            kvStore.push(paramIndex, qnet.params_grad[k], priority=-paramIndex)
            kvStore.pull(paramIndex, qnet.params[k], priority=-paramIndex)

        for v in qnet.params.values():
            v.wait_to_read()
        logging.info("wait time %f" %(time.time()-time_before_wait))

    for epoch in xrange(epoch_num):
        # Run Epoch
        steps_left = steps_per_epoch
        episode = 0
        epoch_reward = 0
        start = time.time()
        game.start()
        while steps_left > 0:
            # Running New Episode
            episode += 1
            episode_loss = 0.0
            episode_q_value = 0.0
            episode_update_step = 0
            episode_action_step = 0
            time_episode_start = time.time()
            game.begin_episode(steps_left)
            while not game.episode_terminate:
                # 1. We need to choose a new action based on the current game status
                if game.state_enabled and game.replay_memory.sample_enabled:
                    do_exploration = (npy_rng.rand() < eps_curr)
                    eps_curr = max(eps_curr - eps_decay, eps_min)
                    if do_exploration:
                        action = npy_rng.randint(action_num)
                    else:
                        # TODO Here we can in fact play multiple gaming instances simultaneously and make actions for each
                        # We can simply stack the current_state() of gaming instances and give prediction for all of them
                        # We need to wait after calling calc_score(.), which makes the program slow
                        # TODO Profiling the speed of this part!
                        current_state = game.current_state()
                        state = nd.array(current_state.reshape((1,) + current_state.shape),
                                         ctx=q_ctx) / float(255.0)
                        qval_npy = qnet.forward(batch_size=1, data=state)[0].asnumpy()
                        action = numpy.argmax(qval_npy)
                        episode_q_value += qval_npy[0, action]
                        episode_action_step += 1
                else:
                    action = npy_rng.randint(action_num)

                # 2. Play the game for a single mega-step (Inside the game, the action may be repeated for several times)
                game.play(action)
                total_steps += 1

                # 3. Update our Q network if we can start sampling from the replay memory
                #    Also, we update every `update_interval`
                if total_steps % update_interval == 0 and game.replay_memory.sample_enabled:
                    # 3.1 Draw sample from the replay_memory
                    training_steps += 1
                    episode_update_step += 1
                    states, actions, rewards, next_states, terminate_flags \
                        = game.replay_memory.sample(batch_size=minibatch_size)
                    states = nd.array(states, ctx=q_ctx) / float(255.0)
                    next_states = nd.array(next_states, ctx=q_ctx) / float(255.0)
                    actions = nd.array(actions, ctx=q_ctx)
                    rewards = nd.array(rewards, ctx=q_ctx)
                    terminate_flags = nd.array(terminate_flags, ctx=q_ctx)

                    # 3.2 Use the target network to compute the scores and
                    #     get the corresponding target rewards
                    if not args.double_q:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(target_qval))\
                                           * (1.0 - terminate_flags) * discount
                    else:
                        target_qval = target_qnet.forward(batch_size=minibatch_size,
                                                         data=next_states)[0]
                        qval = qnet.forward(batch_size=minibatch_size, data=next_states)[0]

                        target_rewards = rewards + nd.choose_element_0index(target_qval,
                                                                nd.argmax_channel(qval))\
                                           * (1.0 - terminate_flags) * discount
                    outputs = qnet.forward(batch_size=minibatch_size,is_train=True, data=states,
                                              dqn_action=actions,
                                              dqn_reward=target_rewards)
                    qnet.backward(batch_size=minibatch_size)
                    nd.waitall()
                    time_before_update = time.time()

                    if args.kv_type is not None:
                        if total_steps % kvstore_update_period == 0:
                            update_to_kvstore(kvStore,qnet.params,qnet.params_grad)
                    else:
                        qnet.update(updater=updater)
                    logging.info("update time %f" %(time.time()-time_before_update))
                    time_before_wait = time.time()
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))

                    '''nd.waitall()
                    time_before_wait = time.time()
                    kvStore.push(0,testGrad,priority=0)
                    kvStore.pull(0,testParam,priority=0)
                    nd.waitall()
                    logging.info("wait time %f" %(time.time()-time_before_wait))'''
                    # 3.3 Calculate Loss
                    diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards)
                    quadratic_part = nd.clip(diff, -1, 1)
                    loss = (0.5 * nd.sum(nd.square(quadratic_part)) + nd.sum(diff - quadratic_part)).asscalar()
                    episode_loss += loss

                    # 3.3 Update the target network every freeze_interval
                    # (We can do annealing instead of hard copy)
                    if training_steps % freeze_interval == 0:
                        qnet.copy_params_to(target_qnet)
            steps_left -= game.episode_step
            time_episode_end = time.time()
            # Update the statistics
            epoch_reward += game.episode_reward
            info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \
                        % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward,
                           game.episode_step / (time_episode_end - time_episode_start), eps_curr)
            if episode_update_step > 0:
                info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step,
                                                  episode_update_step)
            if episode_action_step > 0:
                info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step,
                                                  episode_action_step)
            logging.info(info_str)
        end = time.time()
        fps = steps_per_epoch / (end - start)
        qnet.save_params(dir_path=args.dir_path, epoch=epoch)
        logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d"
                     % (epoch, fps, epoch_reward / float(episode), episode))