def train_rnn(args):
    """Train the RNN decoder (plus the state-transform net) while the CNN
    encoder is held fixed.

    Loads the Meitu video dataloader, builds encoder / state_trans / decoder,
    optionally warm-starts them from pretrained weights, then runs a
    teacher-free greedy decoding loop: at each step the decoder's argmax
    output is fed back as the next input, and per-step cross-entropy is
    masked out once a sequence has emitted <eos>.

    NOTE(review): the body reads configuration from a module-level `opt`
    rather than from the `args` parameter (which is unused) -- confirm
    this is intentional. `bos` and `eos` are also assumed to be
    module-level token ids.
    """
    vis = Visulizer(env=opt.env)
    vis.log(opt)
    batch, context, steps = parse_basic(opt)
    context = context[0]  # single-device training
    train_loader, val_loader = get_meitu_dataloader(
        opt.meitu_dir,
        opt.decoder_gpu,
        batch_size=batch,
        num_workers=opt.num_workers,
        n_frame=opt.n_frame,
        crop_size=opt.crop_size,
        scale_w=opt.scale_w,
        scale_h=opt.scale_h)

    # Three sub-networks: a frozen CNN feature extractor, a small net that
    # maps CNN features (512-d) to the decoder's initial state (256-d), and
    # the caption RNN decoder itself.
    encoder = Encoder(num_class=63, model_depth=34, embed_size=200)
    state_trans = State_Trans(512, 256)
    decoder = Decoder(embed_size=256, hidden_size=256, vocab_size=66,
                      num_layers=1, max_seq_length=5)
    encoder.initialize(mx.init.Xavier(), ctx=context)
    state_trans.initialize(mx.init.Xavier(), ctx=context)
    decoder.initialize(mx.init.Xavier(), ctx=context)

    max_seq_len = opt.max_seq_len
    lr_opts = {'learning_rate': opt.lr, 'momentum': 0.9, 'wd': opt.wd}

    # Optionally warm-start from pretrained checkpoints.
    if opt.encoder_pre is not None:
        encoder.custom_load_params(opt.encoder_pre)
    if opt.state_trans_pre is not None:
        state_trans.load_parameters(opt.state_trans_pre)
    if opt.decoder_pre is not None:
        decoder.load_parameters(opt.decoder_pre)

    # Only state_trans and decoder are optimized; the encoder stays frozen
    # (no trainer is created for it and it runs outside autograd.record).
    trainer1 = Trainer(state_trans.collect_params(), 'sgd', lr_opts,
                       kvstore=opt.kvstore)
    trainer2 = Trainer(decoder.collect_params(), 'sgd', lr_opts,
                       kvstore=opt.kvstore)
    lr_steps = MultiFactorScheduler(steps, factor=opt.lr_scheduler_factor)
    loss_criterion = gloss.SoftmaxCrossEntropyLoss()

    for epoch in range(opt.num_epoch):
        l_sum = 0
        tic = time()
        pre_loss, cumulative_loss = 0.0, 0.0
        trainer1.set_learning_rate(lr_steps(epoch))
        trainer2.set_learning_rate(lr_steps(epoch))
        # BUGFIX: the original format string had a single placeholder for
        # two values ('[Epoch %d,set learning rate' % (epoch, lr)), which
        # raises TypeError at runtime.
        vis.log('[Epoch %d] set learning rate to %f'
                % (epoch, trainer1.learning_rate))
        for i, (data, label) in enumerate(train_loader):
            # train rnn; cnn-rnn share one context
            data = data.as_in_context(context)
            label = label.as_in_context(context)
            features = encoder(data)  # shape [N, C]
            with autograd.record():
                inputs = nd.ones(shape=(1, batch), ctx=context) * bos
                mask = nd.ones(shape=(1, batch), ctx=context)
                val_length = nd.array([0], ctx=context)
                feat_states = state_trans(features)
                states = decoder.begin_state(batch_size=batch, func=nd.zeros,
                                             vide_feat=feat_states)
                loss = nd.array([0], ctx=context)
                # BUGFIX: the decode-step index was also named `i`,
                # shadowing the batch index from enumerate() and breaking
                # the log-interval check below.
                for t in range(max_seq_len):
                    y = label[t]
                    # outputs shape is 1NC, states is list of [LNC]
                    outputs, states = decoder(inputs, states)
                    # greedy feedback: argmax token becomes the next input
                    inputs = outputs.argmax(axis=2)  # shape 1xN
                    val_length = val_length + mask.sum()
                    outputs = outputs.reshape(batch, -1)
                    loss = loss + (loss_criterion(outputs, y) * mask).sum()
                    # stop counting a sequence once it has emitted <eos>
                    mask = mask * (inputs != eos)
                loss = loss / val_length
            loss.backward()
            trainer1.step(1)
            trainer2.step(1)
            step_loss = loss.asscalar()
            l_sum += step_loss
            # BUGFIX: cumulative_loss was never updated in the original, so
            # the interval log always printed 0.0.
            cumulative_loss += step_loss
            if (i + 1) % (opt.log_interval) == 0:
                vis.log('Epoch %d,Iter %d,Training loss=%f'
                        % (epoch, i + 1, cumulative_loss - pre_loss))
                pre_loss = cumulative_loss
            if opt.debug:
                break