def main():
    """Train the ResNet-50 encoder/decoder captioning model on COCO 2017
    using MXNet Gluon with two-GPU data parallelism.

    Side effects: reads COCO images/annotations and HDF5 feature files from
    hard-coded paths, writes a log file and per-epoch checkpoints under
    ``output/``. No parameters, no return value.
    """
    epoches = 32
    gpu_id = 7  # device used by validate(); training uses ctx_list below
    ctx_list = [mx.gpu(x) for x in [7, 8]]
    log_interval = 100
    batch_size = 32
    start_epoch = 0
    # trainer_resume = resume + ".states" if resume is not None else None
    trainer_resume = None
    resume = None

    from mxnet.gluon.data.vision import transforms
    # Pad to 256x256 (keeping content top-left), then standard ImageNet
    # normalization.
    transform_fn = transforms.Compose([
        LeftTopPad(dest_shape=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/train2017",
        annotation_path="/data3/zyx/yks/coco2017/annotations/captions_train2017.json",
        transforms=transform_fn,
        feature_hdf5="output/train2017.h5")
    # Validation set reuses the training vocabulary so indices agree.
    val_dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/val2017",
        annotation_path="/data3/zyx/yks/coco2017/annotations/captions_val2017.json",
        words2index=dataset.words2index,
        index2words=dataset.index2words,
        transforms=transform_fn,
        feature_hdf5="output/val2017.h5")
    # last_batch="discard" keeps every batch exactly batch_size, which the
    # split_and_load sharding below relies on.
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=True, num_workers=1, pin_memory=True,
                            last_batch="discard")
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=True, num_workers=1, pin_memory=True)
    num_words = dataset.words_count

    # Set up the root logger with an extra file handler next to checkpoints.
    save_prefix = "output/res50_"
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    net = EncoderDecoder(num_words=num_words, test_max_len=val_dataset.max_len)
    if resume is not None:
        net.collect_params().load(resume, allow_missing=True,
                                  ignore_extra=True)
        # BUGFIX: message typo "form" -> "from".
        logger.info("Resumed from checkpoint {}.".format(resume))
    # Initialize any parameter not covered by the checkpoint, choosing the
    # initializer from the parameter's name (BN stats/biases -> Zero/One,
    # weights -> Normal).
    params = net.collect_params()
    for key in params.keys():
        if params[key]._data is not None:
            continue  # already loaded from the checkpoint
        else:
            # BUGFIX(consistency): use the configured `logger` (with file
            # handler) instead of the bare `logging` module functions.
            if "bias" in key or "mean" in key or "beta" in key:
                params[key].initialize(init=mx.init.Zero())
                logger.info("initialized {} using Zero.".format(key))
            elif "weight" in key:
                params[key].initialize(init=mx.init.Normal())
                logger.info("initialized {} using Normal.".format(key))
            elif "var" in key or "gamma" in key:
                params[key].initialize(init=mx.init.One())
                logger.info("initialized {} using One.".format(key))
            else:
                params[key].initialize(init=mx.init.Normal())
                logger.info("initialized {} using Normal.".format(key))
    net.collect_params().reset_ctx(ctx=ctx_list)

    trainer = mx.gluon.Trainer(
        net.collect_params(),
        'adam',
        {
            'learning_rate': 4e-4,
            'clip_gradient': 5,
            'multi_precision': True
        },
    )
    if trainer_resume is not None:
        trainer.load_states(trainer_resume)
        # BUGFIX: message typo "form" -> "from".
        logger.info(
            "Loaded trainer states from checkpoint {}.".format(trainer_resume))

    criterion = Criterion()
    accu_top3_metric = TopKAccuracy(top_k=3)
    accu_top1_metric = Accuracy(name="batch_accu")
    ctc_loss_metric = Loss(name="ctc_loss")
    alpha_metric = Loss(name="alpha_loss")
    batch_bleu = BleuMetric(name="batch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    epoch_bleu = BleuMetric(name="epoch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    btic = time.time()
    # Sanity-log dataset/vocabulary stats before training starts.
    logger.info(batch_size)
    logger.info(num_words)
    logger.info(len(dataset.words2index))
    logger.info(len(dataset.index2words))
    logger.info(dataset.words2index["<PAD>"])
    logger.info(val_dataset.words2index["<PAD>"])
    logger.info(len(val_dataset.words2index))
    # net.hybridize(static_alloc=True, static_shape=True)

    net_parallel = DataParallelModel(net, ctx_list=ctx_list, sync=True)
    for nepoch in range(start_epoch, epoches):
        if nepoch > 15:
            # Step the learning rate down by 10x after epoch 15.
            trainer.set_learning_rate(4e-5)
        logger.info("Current lr: {}".format(trainer.learning_rate))
        accu_top1_metric.reset()
        accu_top3_metric.reset()
        ctc_loss_metric.reset()
        alpha_metric.reset()
        epoch_bleu.reset()
        batch_bleu.reset()
        for nbatch, batch in enumerate(tqdm.tqdm(dataloader)):
            # Shard each field across GPUs, then regroup per device.
            batch = [mx.gluon.utils.split_and_load(x, ctx_list) for x in batch]
            inputs = [[x[n] for x in batch] for n, _ in enumerate(ctx_list)]
            losses = []
            with ag.record():
                net_parallel.sync = nbatch > 1
                outputs = net_parallel(*inputs)
                for s_batch, s_outputs in zip(inputs, outputs):
                    image, label, label_len = s_batch
                    predictions, alphas = s_outputs
                    ctc_loss = criterion(predictions, label, label_len)
                    # Doubly-stochastic attention regularizer: pushes the
                    # per-location attention weights to sum to ~1 over time.
                    loss2 = 1.0 * ((1. - alphas.sum(axis=1))**2).mean()
                    losses.extend([ctc_loss, loss2])
            ag.backward(losses)
            trainer.step(batch_size=batch_size, ignore_stale_grad=True)
            # Metrics below use only the last device's shard (predictions,
            # label, label_len from the final loop iteration above).
            for n, l in enumerate(label_len):
                l = int(l.asscalar())
                # Skip the leading <BOS>-style token in the label; align the
                # prediction window accordingly.
                la = label[n, 1:l]
                pred = predictions[n, :(l - 1)]
                accu_top3_metric.update(la, pred)
                accu_top1_metric.update(la, pred)
                epoch_bleu.update(la, predictions[n, :])
                batch_bleu.update(la, predictions[n, :])
            ctc_loss_metric.update(None,
                                   preds=nd.sum(ctc_loss) / image.shape[0])
            alpha_metric.update(None, preds=loss2)
            if nbatch % log_interval == 0 and nbatch > 0:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get()) for metric in [
                        epoch_bleu, batch_bleu, accu_top1_metric,
                        accu_top3_metric, ctc_loss_metric, alpha_metric
                    ]
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(nepoch, nbatch,
                           log_interval * batch_size / (time.time() - btic),
                           msg))
                btic = time.time()
                # batch-level metrics restart each logging window.
                batch_bleu.reset()
                accu_top1_metric.reset()
                accu_top3_metric.reset()
                ctc_loss_metric.reset()
                alpha_metric.reset()
        bleu, acc_top1 = validate(net,
                                  gpu_id=gpu_id,
                                  val_loader=val_loader,
                                  train_index2words=dataset.index2words,
                                  val_index2words=val_dataset.index2words)
        save_path = save_prefix + "_weights-%d-bleu-%.4f-%.4f.params" % (
            nepoch, bleu, acc_top1)
        net.collect_params().save(save_path)
        trainer.save_states(fname=save_path + ".states")
        logger.info("Saved checkpoint to {}.".format(save_path))
def main():
    """Train the encoder/decoder captioning model on COCO 2017 with PyTorch.

    NOTE(review): this redefinition shadows the MXNet ``main`` defined
    earlier in this file — only this version runs. Consider renaming one
    of them.

    Side effects: reads COCO images/annotations and HDF5 feature files from
    hard-coded paths, writes a log file and per-epoch checkpoints under
    ``output/``. No parameters, no return value.
    """
    epoches = 32
    gpu_id = 7  # device used by validate()
    # NOTE: removed unused leftover `ctx_list = [mx.gpu(x) for x in [7, 8]]`
    # from the MXNet version — torch.nn.DataParallel handles devices here.
    log_interval = 100
    batch_size = 32
    start_epoch = 0
    # trainer_resume = resume + ".states" if resume is not None else None
    trainer_resume = None
    resume = None

    from mxnet.gluon.data.vision import transforms
    # Pad to 256x256 (keeping content top-left), then standard ImageNet
    # normalization. Data pipeline stays on MXNet; tensors are converted to
    # torch inside the training loop.
    transform_fn = transforms.Compose([
        LeftTopPad(dest_shape=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/train2017",
        annotation_path="/data3/zyx/yks/coco2017/annotations/captions_train2017.json",
        transforms=transform_fn,
        feature_hdf5="output/train2017.h5")
    # Validation set reuses the training vocabulary so indices agree.
    val_dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/val2017",
        annotation_path="/data3/zyx/yks/coco2017/annotations/captions_val2017.json",
        words2index=dataset.words2index,
        index2words=dataset.index2words,
        transforms=transform_fn,
        feature_hdf5="output/val2017.h5")
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=True, num_workers=1, pin_memory=True,
                            last_batch="discard")
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                            shuffle=True, num_workers=1, pin_memory=True)
    num_words = dataset.words_count

    # Set up the root logger with an extra file handler next to checkpoints.
    save_prefix = "output/res50_"
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    net = EncoderDecoder(num_words=num_words,
                         test_max_len=val_dataset.max_len).cuda()
    # Simple manual init: zeros for biases, N(0, 0.01) for everything else.
    for name, p in net.named_parameters():
        if "bias" in name:
            p.data.zero_()
        else:
            p.data.normal_(0, 0.01)
        # Consistency fix: log through the configured logger, not print().
        logger.info(name)
    net = torch.nn.DataParallel(net)
    if resume is not None:
        # BUGFIX: the original called MXNet-only net.collect_params().load()
        # on a torch module, which would raise AttributeError. A torch
        # checkpoint is restored via load_state_dict.
        net.module.load_state_dict(torch.load(resume))
        logger.info("Resumed from checkpoint {}.".format(resume))
    trainer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                             net.parameters()),
                               lr=4e-4)

    criterion = Criterion()
    accu_top3_metric = TopKAccuracy(top_k=3)
    accu_top1_metric = Accuracy(name="batch_accu")
    ctc_loss_metric = Loss(name="ctc_loss")
    alpha_metric = Loss(name="alpha_loss")
    batch_bleu = BleuMetric(name="batch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    epoch_bleu = BleuMetric(name="epoch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    btic = time.time()
    # Sanity-log dataset/vocabulary stats before training starts.
    logger.info(batch_size)
    logger.info(num_words)
    logger.info(len(dataset.words2index))
    logger.info(len(dataset.index2words))
    logger.info(dataset.words2index["<PAD>"])
    logger.info(val_dataset.words2index["<PAD>"])
    logger.info(len(val_dataset.words2index))

    for nepoch in range(start_epoch, epoches):
        if nepoch > 15:
            # BUGFIX: torch optimizers have no set_learning_rate(); the lr
            # is updated through param_groups. 10x decay after epoch 15.
            for group in trainer.param_groups:
                group["lr"] = 4e-5
        logger.info("Current lr: {}".format(trainer.param_groups[0]["lr"]))
        accu_top1_metric.reset()
        accu_top3_metric.reset()
        ctc_loss_metric.reset()
        alpha_metric.reset()
        epoch_bleu.reset()
        batch_bleu.reset()
        for nbatch, batch in enumerate(tqdm.tqdm(dataloader)):
            # Dataset yields MXNet NDArrays; convert to CUDA torch tensors.
            batch = [
                Variable(torch.from_numpy(x.asnumpy()).cuda()) for x in batch
            ]
            data, label, label_len = batch
            label = label.long()
            label_len = label_len.long()
            max_len = label_len.max().data.cpu().numpy()
            net.train()
            # BUGFIX: without zero_grad() the gradients of every previous
            # batch accumulate into this step.
            trainer.zero_grad()
            outputs = net(data, label, max_len)
            predictions, alphas = outputs
            ctc_loss = criterion(predictions, label, label_len)
            # Doubly-stochastic attention regularizer: pushes the
            # per-location attention weights to sum to ~1 over time.
            loss2 = 1.0 * ((1. - alphas.sum(dim=1))**2).mean()
            ((ctc_loss + loss2) / batch_size).backward()
            # Element-wise gradient clipping to [-5, 5], mirroring the
            # clip_gradient=5 setting of the MXNet version.
            for group in trainer.param_groups:
                for param in group['params']:
                    if param.grad is not None:
                        param.grad.data.clamp_(-5, 5)
            trainer.step()
            if nbatch % 10 == 0:
                # Metric updates are throttled to every 10th batch; the
                # mx.nd wrappers feed the MXNet-based metric classes.
                for n, l in enumerate(label_len):
                    l = int(l.data.cpu().numpy())
                    # Skip the leading <BOS>-style token in the label; align
                    # the prediction window accordingly.
                    la = label[n, 1:l].data.cpu().numpy()
                    pred = predictions[n, :(l - 1)].data.cpu().numpy()
                    accu_top3_metric.update(mx.nd.array(la),
                                            mx.nd.array(pred))
                    accu_top1_metric.update(mx.nd.array(la),
                                            mx.nd.array(pred))
                    epoch_bleu.update(la, predictions[n, :].data.cpu().numpy())
                    batch_bleu.update(la, predictions[n, :].data.cpu().numpy())
                ctc_loss_metric.update(
                    None,
                    preds=mx.nd.array([ctc_loss.data.cpu().numpy()]) /
                    batch_size)
                alpha_metric.update(None,
                                    preds=mx.nd.array(
                                        [loss2.data.cpu().numpy()]))
            if nbatch % log_interval == 0 and nbatch > 0:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get()) for metric in [
                        epoch_bleu, batch_bleu, accu_top1_metric,
                        accu_top3_metric, ctc_loss_metric, alpha_metric
                    ]
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(nepoch, nbatch,
                           log_interval * batch_size / (time.time() - btic),
                           msg))
                btic = time.time()
                # batch-level metrics restart each logging window.
                batch_bleu.reset()
                accu_top1_metric.reset()
                accu_top3_metric.reset()
                ctc_loss_metric.reset()
                alpha_metric.reset()
        net.eval()
        bleu, acc_top1 = validate(net,
                                  gpu_id=gpu_id,
                                  val_loader=val_loader,
                                  train_index2words=dataset.index2words,
                                  val_index2words=val_dataset.index2words)
        save_path = save_prefix + "_weights-%d-bleu-%.4f-%.4f.params" % (
            nepoch, bleu, acc_top1)
        # Save the unwrapped module so the checkpoint has no DataParallel
        # key prefixes.
        torch.save(net.module.state_dict(), save_path)
        torch.save(trainer.state_dict(), save_path + ".states")
        logger.info("Saved checkpoint to {}.".format(save_path))