help='Enable built-in profiler (0=off, 1=on)')
opt = parser.parse_args()

# global variables
logger.info('Starting new image-classification task: %s', opt)
mx.random.seed(opt.seed)
model_name = opt.model
dataset_classes = {'mnist': 10, 'cifar10': 10, 'imagenet': 1000, 'dummy': 1000}
batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[
    opt.dataset]
context = [mx.gpu(int(i))
           for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
num_gpus = len(context)
batch_size *= max(1, num_gpus)
lr_steps = [int(x) for x in opt.lr_steps.split(',') if x.strip()]
metric = CompositeEvalMetric([Accuracy(), TopKAccuracy(5)])


def get_model(model, ctx, opt):
    """Model initialization."""
    kwargs = {'ctx': ctx, 'pretrained': opt.use_pretrained, 'classes': classes}
    if model.startswith('resnet'):
        kwargs['thumbnail'] = opt.use_thumbnail
    elif model.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm

    net = models.get_model(model, **kwargs)
    if opt.resume:
        net.load_params(opt.resume)
    elif not opt.use_pretrained:
        if model in ['alexnet']:
            # AlexNet converges better from a plain Normal init;
            # the other models use Xavier
            net.initialize(mx.init.Normal())
        else:
            net.initialize(mx.init.Xavier(magnitude=2))
    return net
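
A minimal usage sketch, assuming the globals parsed above:

net = get_model(model_name, context, opt)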
Example 2
model_name = opt.model
dataset_classes = {
    'mnist': 10,
    'cifar10': 10,
    'caltech101': 101,
    'imagenet': 1000,
    'dummy': 1000
}
batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[
    opt.dataset]
context = [mx.gpu(int(i))
           for i in opt.gpus.split(',')] if opt.gpus.strip() else [mx.cpu()]
num_gpus = len(context)
batch_size *= max(1, num_gpus)
lr_steps = [int(x) for x in opt.lr_steps.split(',') if x.strip()]
metric = CompositeEvalMetric([Accuracy(), TopKAccuracy(5), CrossEntropy()])
kv = mx.kv.create(opt.kvstore)


def get_data_iters(dataset, batch_size, opt):
    """get dataset iterators"""
    if dataset == 'mnist':
        train_data, val_data = get_mnist_iterator(batch_size, (1, 28, 28),
                                                  num_parts=kv.num_workers,
                                                  part_index=kv.rank)
    elif dataset == 'cifar10':
        train_data, val_data = get_cifar10_iterator(batch_size, (3, 32, 32),
                                                    num_parts=kv.num_workers,
                                                    part_index=kv.rank)
    elif dataset == 'imagenet':
        shape_dim = 299 if model_name == 'inceptionv3' else 224
        # the helpers below are assumed to come from the same data module as
        # get_mnist_iterator / get_cifar10_iterator above
        train_data, val_data = get_imagenet_iterator(opt.data_dir, batch_size,
                                                     opt.num_workers, shape_dim,
                                                     opt.dtype)
    elif dataset == 'caltech101':
        train_data, val_data = get_caltech101_iterator(batch_size,
                                                       opt.num_workers,
                                                       opt.dtype)
    elif dataset == 'dummy':
        shape_dim = 299 if model_name == 'inceptionv3' else 224
        train_data, val_data = dummy_iterator(batch_size,
                                              (3, shape_dim, shape_dim))
    return train_data, val_data
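
A minimal usage sketch, assuming `opt` from the surrounding script:

train_data, val_data = get_data_iters(dataset, batch_size, opt)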
Example 3
def main():
    epochs = 32
    gpu_id = 7
    ctx_list = [mx.gpu(x) for x in [7, 8]]
    log_interval = 100
    batch_size = 32
    start_epoch = 0
    # trainer_resume = resume + ".states" if resume is not None else None
    trainer_resume = None

    resume = None
    from mxnet.gluon.data.vision import transforms
    transform_fn = transforms.Compose([
        LeftTopPad(dest_shape=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/train2017",
        annotation_path=
        "/data3/zyx/yks/coco2017/annotations/captions_train2017.json",
        transforms=transform_fn,
        feature_hdf5="output/train2017.h5")
    val_dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/val2017",
        annotation_path=
        "/data3/zyx/yks/coco2017/annotations/captions_val2017.json",
        words2index=dataset.words2index,
        index2words=dataset.index2words,
        transforms=transform_fn,
        feature_hdf5="output/val2017.h5")
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True,
                            last_batch="discard")
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True)

    num_words = dataset.words_count

    # set up logger
    save_prefix = "output/res50_"
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    net = EncoderDecoder(num_words=num_words, test_max_len=val_dataset.max_len)
    if resume is not None:
        net.collect_params().load(resume,
                                  allow_missing=True,
                                  ignore_extra=True)
        logger.info("Resumed form checkpoint {}.".format(resume))
    params = net.collect_params()
    for key in params.keys():
        # parameters restored from a checkpoint already have data; only
        # initialize the ones that are still empty
        if params[key]._data is not None:
            continue
        if "bias" in key or "mean" in key or "beta" in key:
            params[key].initialize(init=mx.init.Zero())
            logger.info("initialized {} using Zero.".format(key))
        elif "weight" in key:
            params[key].initialize(init=mx.init.Normal())
            logger.info("initialized {} using Normal.".format(key))
        elif "var" in key or "gamma" in key:
            params[key].initialize(init=mx.init.One())
            logger.info("initialized {} using One.".format(key))
        else:
            params[key].initialize(init=mx.init.Normal())
            logger.info("initialized {} using Normal.".format(key))

    net.collect_params().reset_ctx(ctx=ctx_list)
    trainer = mx.gluon.Trainer(
        net.collect_params(),
        'adam',
        {
            'learning_rate': 4e-4,
            'clip_gradient': 5,
            'multi_precision': True
        },
    )
    if trainer_resume is not None:
        trainer.load_states(trainer_resume)
        logger.info(
            "Loaded trainer states form checkpoint {}.".format(trainer_resume))
    criterion = Criterion()
    accu_top3_metric = TopKAccuracy(top_k=3)
    accu_top1_metric = Accuracy(name="batch_accu")
    ctc_loss_metric = Loss(name="ctc_loss")
    alpha_metric = Loss(name="alpha_loss")
    batch_bleu = BleuMetric(name="batch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    epoch_bleu = BleuMetric(name="epoch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    btic = time.time()
    logger.info("batch_size=%d, num_words=%d", batch_size, num_words)
    logger.info("train vocab: words2index=%d, index2words=%d, <PAD>=%d",
                len(dataset.words2index), len(dataset.index2words),
                dataset.words2index["<PAD>"])
    logger.info("val vocab: words2index=%d, <PAD>=%d",
                len(val_dataset.words2index), val_dataset.words2index["<PAD>"])
    # net.hybridize(static_alloc=True, static_shape=True)
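    # DataParallelModel scatters each forward pass across ctx_list; with
    # sync=True, gradients are synchronized across the devices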
    net_parallel = DataParallelModel(net, ctx_list=ctx_list, sync=True)
    for nepoch in range(start_epoch, epochs):
        if nepoch > 15:
            trainer.set_learning_rate(4e-5)
        logger.info("Current lr: {}".format(trainer.learning_rate))
        accu_top1_metric.reset()
        accu_top3_metric.reset()
        ctc_loss_metric.reset()
        alpha_metric.reset()
        epoch_bleu.reset()
        batch_bleu.reset()
        for nbatch, batch in enumerate(tqdm.tqdm(dataloader)):
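            # scatter every array in the batch across the devices, then regroup
            # so each device receives one (image, label, label_len) tuple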
            batch = [mx.gluon.utils.split_and_load(x, ctx_list) for x in batch]
            inputs = [[x[n] for x in batch] for n, _ in enumerate(ctx_list)]
            losses = []
            with ag.record():
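                # presumably defers cross-device sync until deferred parameter
                # initialization has completed on the first batch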
                net_parallel.sync = nbatch > 1
                outputs = net_parallel(*inputs)
                for s_batch, s_outputs in zip(inputs, outputs):
                    image, label, label_len = s_batch
                    predictions, alphas = s_outputs
                    ctc_loss = criterion(predictions, label, label_len)
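                    # doubly-stochastic attention penalty: encourages each
                    # attention weight to sum to ~1 across decoding steps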
                    loss2 = 1.0 * ((1. - alphas.sum(axis=1))**2).mean()
                    losses.extend([ctc_loss, loss2])
            ag.backward(losses)
            trainer.step(batch_size=batch_size, ignore_stale_grad=True)
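            # NOTE: `label_len`, `label`, `predictions`, `ctc_loss` and `loss2`
            # are leftovers from the loop above, so the metric updates below
            # only see the last device's shard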
            for n, l in enumerate(label_len):
                l = int(l.asscalar())
                la = label[n, 1:l]
                pred = predictions[n, :(l - 1)]
                accu_top3_metric.update(la, pred)
                accu_top1_metric.update(la, pred)
                epoch_bleu.update(la, predictions[n, :])
                batch_bleu.update(la, predictions[n, :])
            ctc_loss_metric.update(None,
                                   preds=nd.sum(ctc_loss) / image.shape[0])
            alpha_metric.update(None, preds=loss2)
            if nbatch % log_interval == 0 and nbatch > 0:
                msg = ','.join([
                    '{}={:.3f}'.format(*metric.get()) for metric in [
                        epoch_bleu, batch_bleu, accu_top1_metric,
                        accu_top3_metric, ctc_loss_metric, alpha_metric
                    ]
                ])
                logger.info(
                    '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                    format(nepoch, nbatch,
                           log_interval * batch_size / (time.time() - btic),
                           msg))
                btic = time.time()
                batch_bleu.reset()
                accu_top1_metric.reset()
                accu_top3_metric.reset()
                ctc_loss_metric.reset()
                alpha_metric.reset()

        bleu, acc_top1 = validate(net,
                                  gpu_id=gpu_id,
                                  val_loader=val_loader,
                                  train_index2words=dataset.index2words,
                                  val_index2words=val_dataset.index2words)
        save_path = save_prefix + "_weights-%d-bleu-%.4f-%.4f.params" % (
            nepoch, bleu, acc_top1)
        net.collect_params().save(save_path)
        trainer.save_states(fname=save_path + ".states")
        logger.info("Saved checkpoint to {}.".format(save_path))
Example 4
def main():
    epochs = 32
    gpu_id = 7
    ctx_list = [mx.gpu(x) for x in [7, 8]]
    log_interval = 100
    batch_size = 32
    start_epoch = 0
    # trainer_resume = resume + ".states" if resume is not None else None
    trainer_resume = None

    resume = None
    from mxnet.gluon.data.vision import transforms
    transform_fn = transforms.Compose([
        LeftTopPad(dest_shape=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/train2017",
        annotation_path=
        "/data3/zyx/yks/coco2017/annotations/captions_train2017.json",
        transforms=transform_fn,
        feature_hdf5="output/train2017.h5")
    val_dataset = CaptionDataSet(
        image_root="/data3/zyx/yks/coco2017/val2017",
        annotation_path=
        "/data3/zyx/yks/coco2017/annotations/captions_val2017.json",
        words2index=dataset.words2index,
        index2words=dataset.index2words,
        transforms=transform_fn,
        feature_hdf5="output/val2017.h5")
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True,
                            last_batch="discard")
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1,
                            pin_memory=True)

    num_words = dataset.words_count

    # set up logger
    save_prefix = "output/res50_"
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    net = EncoderDecoder(num_words=num_words,
                         test_max_len=val_dataset.max_len).cuda()
    for name, p in net.named_parameters():
        if "bias" in name:
            p.data.zero_()
        else:
            p.data.normal_(0, 0.01)
        print(name)
    net = torch.nn.DataParallel(net)
    if resume is not None:
        # `collect_params()` is MXNet API; the torch model needs a state_dict
        # load instead (strict=False roughly mirrors allow_missing/ignore_extra)
        net.module.load_state_dict(torch.load(resume), strict=False)
        logger.info("Resumed from checkpoint {}.".format(resume))

    trainer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                             net.parameters()),
                               lr=4e-4)
    criterion = Criterion()
    accu_top3_metric = TopKAccuracy(top_k=3)
    accu_top1_metric = Accuracy(name="batch_accu")
    ctc_loss_metric = Loss(name="ctc_loss")
    alpha_metric = Loss(name="alpha_loss")
    batch_bleu = BleuMetric(name="batch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    epoch_bleu = BleuMetric(name="epoch_bleu",
                            pred_index2words=dataset.index2words,
                            label_index2words=dataset.index2words)
    btic = time.time()
    logger.info("batch_size=%d, num_words=%d", batch_size, num_words)
    logger.info("train vocab: words2index=%d, index2words=%d, <PAD>=%d",
                len(dataset.words2index), len(dataset.index2words),
                dataset.words2index["<PAD>"])
    logger.info("val vocab: words2index=%d, <PAD>=%d",
                len(val_dataset.words2index), val_dataset.words2index["<PAD>"])
    for nepoch in range(start_epoch, epochs):
        if nepoch > 15:
            # torch optimizers have no set_learning_rate(); update the
            # param groups directly
            for group in trainer.param_groups:
                group["lr"] = 4e-5
        logger.info("Current lr: {}".format(trainer.param_groups[0]["lr"]))
        accu_top1_metric.reset()
        accu_top3_metric.reset()
        ctc_loss_metric.reset()
        alpha_metric.reset()
        epoch_bleu.reset()
        batch_bleu.reset()
        for nbatch, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = [
                Variable(torch.from_numpy(x.asnumpy()).cuda()) for x in batch
            ]
            data, label, label_len = batch
            label = label.long()
            label_len = label_len.long()
            max_len = label_len.max().data.cpu().numpy()
            net.train()
            outputs = net(data, label, max_len)
            predictions, alphas = outputs
            ctc_loss = criterion(predictions, label, label_len)
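            # doubly-stochastic attention penalty, as in the Gluon variant
            # above: encourages attention weights to sum to ~1 over time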
            loss2 = 1.0 * ((1. - alphas.sum(dim=1))**2).mean()
            ((ctc_loss + loss2) / batch_size).backward()
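            # element-wise gradient clamp to [-5, 5], mirroring
            # clip_gradient=5 in the Gluon Trainer above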
            for group in trainer.param_groups:
                for param in group['params']:
                    if param.grad is not None:
                        param.grad.data.clamp_(-5, 5)

            trainer.step()
            if nbatch % 10 == 0:
                for n, l in enumerate(label_len):
                    l = int(l.data.cpu().numpy())
                    la = label[n, 1:l].data.cpu().numpy()
                    pred = predictions[n, :(l - 1)].data.cpu().numpy()
                    accu_top3_metric.update(mx.nd.array(la), mx.nd.array(pred))
                    accu_top1_metric.update(mx.nd.array(la), mx.nd.array(pred))
                    epoch_bleu.update(la, predictions[n, :].data.cpu().numpy())
                    batch_bleu.update(la, predictions[n, :].data.cpu().numpy())
                ctc_loss_metric.update(
                    None,
                    preds=mx.nd.array([ctc_loss.data.cpu().numpy()]) /
                    batch_size)
                alpha_metric.update(None,
                                    preds=mx.nd.array(
                                        [loss2.data.cpu().numpy()]))
                if nbatch % log_interval == 0 and nbatch > 0:
                    msg = ','.join([
                        '{}={:.3f}'.format(*metric.get()) for metric in [
                            epoch_bleu, batch_bleu, accu_top1_metric,
                            accu_top3_metric, ctc_loss_metric, alpha_metric
                        ]
                    ])
                    logger.info(
                        '[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.
                        format(
                            nepoch, nbatch,
                            log_interval * batch_size / (time.time() - btic),
                            msg))
                    btic = time.time()
                    batch_bleu.reset()
                    accu_top1_metric.reset()
                    accu_top3_metric.reset()
                    ctc_loss_metric.reset()
                    alpha_metric.reset()
        net.eval()
        bleu, acc_top1 = validate(net,
                                  gpu_id=gpu_id,
                                  val_loader=val_loader,
                                  train_index2words=dataset.index2words,
                                  val_index2words=val_dataset.index2words)
        save_path = save_prefix + "_weights-%d-bleu-%.4f-%.4f.params" % (
            nepoch, bleu, acc_top1)
        torch.save(net.module.state_dict(), save_path)
        torch.save(trainer.state_dict(), save_path + ".states")
        logger.info("Saved checkpoint to {}.".format(save_path))