def train_rnn(args):
    """Train the RNN decoder (plus state-transform) while keeping the CNN encoder frozen.

    The CNN ``Encoder`` extracts clip features outside the autograd scope, so
    only ``state_trans`` and ``decoder`` receive gradient updates. Decoding is
    greedy: at each step the argmax token is fed back as the next input, and a
    mask zeroes the loss for sequences that have already emitted EOS.

    NOTE(review): the function ignores ``args`` and reads the module-level
    ``opt`` config (and ``bos``/``eos`` token ids) instead — presumably
    intentional, but worth confirming.
    """
    vis = Visulizer(env=opt.env)
    vis.log(opt)
    batch, context, steps = parse_basic(opt)
    context = context[0]  # train on a single device
    train_loader, val_loader = get_meitu_dataloader(opt.meitu_dir,
                                                    opt.decoder_gpu,
                                                    batch_size=batch,
                                                    num_workers=opt.num_workers,
                                                    n_frame=opt.n_frame,
                                                    crop_size=opt.crop_size,
                                                    scale_w=opt.scale_w,
                                                    scale_h=opt.scale_h)
    encoder = Encoder(num_class=63, model_depth=34, embed_size=200)
    state_trans = State_Trans(512, 256)
    decoder = Decoder(embed_size=256, hidden_size=256, vocab_size=66,
                      num_layers=1, max_seq_length=5)
    encoder.initialize(mx.init.Xavier(), ctx=context)
    state_trans.initialize(mx.init.Xavier(), ctx=context)
    decoder.initialize(mx.init.Xavier(), ctx=context)
    max_seq_len = opt.max_seq_len
    lr_opts = {'learning_rate': opt.lr, 'momentum': 0.9, 'wd': opt.wd}
    # Optionally resume from pretrained weights.
    if opt.encoder_pre is not None:
        encoder.custom_load_params(opt.encoder_pre)
    if opt.state_trans_pre is not None:
        state_trans.load_parameters(opt.state_trans_pre)
    if opt.decoder_pre is not None:
        decoder.load_parameters(opt.decoder_pre)

    # Only state_trans and decoder get trainers; the encoder stays fixed.
    trainer1 = Trainer(state_trans.collect_params(), 'sgd', lr_opts, kvstore=opt.kvstore)
    trainer2 = Trainer(decoder.collect_params(), 'sgd', lr_opts, kvstore=opt.kvstore)
    lr_steps = MultiFactorScheduler(steps, factor=opt.lr_scheduler_factor)
    loss_criterion = gloss.SoftmaxCrossEntropyLoss()
    for epoch in range(opt.num_epoch):
        l_sum = 0
        tic = time()
        pre_loss, cumulative_loss = 0.0, 0.0
        trainer1.set_learning_rate(lr_steps(epoch))
        trainer2.set_learning_rate(lr_steps(epoch))
        # BUG FIX: the format string had one placeholder but two arguments,
        # which raised TypeError at runtime.
        vis.log('[Epoch %d] set learning rate to %f' % (epoch, trainer1.learning_rate))

        for i, (data, label) in enumerate(train_loader):
            # train rnn; cnn-rnn share one context
            data = data.as_in_context(context)
            label = label.as_in_context(context)
            features = encoder(data)  # [N, C] clip features; no grad recorded

            with autograd.record():
                # Every sequence starts with the BOS token; mask starts all-on.
                inputs = nd.ones(shape=(1, batch), ctx=context) * bos
                mask = nd.ones(shape=(1, batch), ctx=context)
                val_length = nd.array([0], ctx=context)
                feat_states = state_trans(features)
                states = decoder.begin_state(batch_size=batch, func=nd.zeros, vide_feat=feat_states)
                loss = nd.array([0], ctx=context)
                # BUG FIX: the decode loop used to reuse `i`, shadowing the
                # batch index and breaking the log-interval check below.
                for t in range(max_seq_len):
                    y = label[t]  # assumes label is (seq_len, N) — TODO confirm
                    outputs, states = decoder(inputs, states)
                    # outputs: [1, N, C]; states: list of [L, N, C]
                    inputs = outputs.argmax(axis=2)  # [1, N] greedy token fed back next step
                    val_length = val_length + mask.sum()
                    outputs = outputs.reshape(batch, -1)
                    # Masked loss: sequences past EOS contribute zero.
                    loss = loss + (loss_criterion(outputs, y) * mask).sum()
                    mask = mask * (inputs != eos)
                loss = loss / val_length
            # Backward and parameter updates belong outside the record scope
            # (standard Gluon pattern; behavior unchanged).
            loss.backward()
            trainer1.step(1)
            trainer2.step(1)
            batch_loss = loss.asscalar()
            l_sum += batch_loss
            # BUG FIX: cumulative_loss was never updated, so the logged delta
            # below was always 0.0.
            cumulative_loss += batch_loss

            if (i + 1) % opt.log_interval == 0:
                vis.log('Epoch %d,Iter %d,Training loss=%f' % (epoch, i + 1,
                                                               cumulative_loss - pre_loss))
                pre_loss = cumulative_loss
                if opt.debug:
                    break