Ejemplo n.º 1
0
def test(args):
    """Evaluate a UNetD checkpoint on the SIDD validation set and print mean PSNR.

    Args:
        args: namespace providing ``data`` (root passed to ``SIDDValData``)
            and ``checkpoint`` (path to a pickled dict holding a
            ``"state_dict"`` entry).

    Raises:
        ValueError: if the validation dataloader yields no batches, in which
            case the average PSNR is undefined.
    """
    valid_dataset = SIDDValData(args.data)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=1, drop_last=False
    )
    valid_dataloader = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        num_workers=8,
    )
    model = UNetD(3)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only load checkpoints from trusted sources.
    with open(args.checkpoint, "rb") as f:
        state = pickle.load(f)
    model.load_state_dict(state["state_dict"])
    model.eval()

    def valid_step(image, label):
        # The model output is subtracted from the input to form the
        # prediction (presumably the model estimates the noise residual).
        pred = model(image)
        pred = image - pred
        psnr_it = batch_PSNR(pred, label)
        return psnr_it

    def valid(func, data_queue):
        psnr_v = 0.
        num_batches = 0
        # Iterate the loader directly so tqdm can report its length.
        for image, label in tqdm(data_queue):
            image = megengine.tensor(image)
            label = megengine.tensor(label)
            psnr_v += func(image, label)
            num_batches += 1
        if num_batches == 0:
            # Previously this surfaced as a NameError on the loop variable.
            raise ValueError("validation dataloader yielded no batches")
        return psnr_v / num_batches

    psnr_v = valid(valid_step, valid_dataloader)
    print("PSNR: {:.3f}".format(psnr_v.item()) )
Ejemplo n.º 2
0
def build_dataloader(batch_size, dataset_dir):
    """Create the shuffled, augmented PascalVOC training dataloader.

    Samples are (image, mask) pairs. They are flipped, rescaled, cropped to
    ``cfg.IMG_SIZE`` (masks padded with 255), then normalized and converted
    by ``T.ToMode()``.

    Returns:
        tuple: ``(dataloader, number_of_training_samples)``.
    """
    voc_train = dataset.PascalVOC(dataset_dir,
                                  cfg.DATA_TYPE,
                                  order=["image", "mask"])
    shuffler = data.RandomSampler(voc_train, batch_size, drop_last=True)

    augmentations = [
        T.RandomHorizontalFlip(0.5),
        T.RandomResize(scale_range=(0.5, 2)),
        T.RandomCrop(
            output_size=(cfg.IMG_SIZE, cfg.IMG_SIZE),
            padding_value=[0, 0, 0],
            padding_maskvalue=255,
        ),
        T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD),
        T.ToMode(),
    ]
    pipeline = T.Compose(transforms=augmentations, order=["image", "mask"])

    loader = data.DataLoader(
        voc_train,
        sampler=shuffler,
        transform=pipeline,
        num_workers=0,
    )
    return loader, len(voc_train)
Ejemplo n.º 3
0
def worker(rank, world_size, args):
    """Evaluate an (optionally quantized) ImageNet classifier on the val split.

    Args:
        rank (int): index of this process; also used as its device id.
        world_size (int): number of processes; when > 1 the distributed
            process group is initialized on localhost:23456.
        args: namespace with ``arch``, ``model`` (weights path, or None to
            request the architecture's pretrained weights), ``quantized``,
            ``data``, ``batch_size`` and ``workers`` (plus whatever ``infer``
            reads).
    """
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    # Pretrained weights are requested only when no explicit weight file is set.
    net = getattr(M, args.arch)(pretrained=(args.model is None))
    if args.model:
        logger.info("load weights from %s", args.model)
        net.load_state_dict(mge.load(args.model), strict=False)
    if args.quantized:
        quantize(net)

    def world_mean(value, key):
        # Sum across all workers under ``key``, then divide to get the mean.
        return dist.all_reduce_sum(value, key) / dist.get_world_size()

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        net.eval()
        logits = net(image)
        loss = F.cross_entropy_with_softmax(logits, label)
        top1, top5 = F.accuracy(logits, label, (1, 5))
        if not dist.is_distributed():
            return loss, top1, top5
        return (world_mean(loss, "valid_loss"),
                world_mean(top1, "valid_acc1"),
                world_mean(top5, "valid_acc5"))

    logger.info("preparing dataset..")
    imagenet_val = data.dataset.ImageNet(args.data, train=False)
    ordered = data.SequentialSampler(imagenet_val,
                                     batch_size=args.batch_size,
                                     drop_last=False)
    eval_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[128.0, 128.0, 128.0],
                    std=[1.0, 1.0, 1.0]),  # BGR
        T.ToMode("CHW"),
    ])
    val_queue = data.DataLoader(
        imagenet_val,
        sampler=ordered,
        transform=eval_transform,
        num_workers=args.workers,
    )
    _, top1_acc, top5_acc = infer(valid_func, val_queue, args)
    logger.info("Valid %.3f / %.3f", top1_acc, top5_acc)
Ejemplo n.º 4
0
def worker(world_size, args):
    """Evaluate an ImageNet classifier in normal, QAT or quantized mode.

    Args:
        world_size (int): number of processes in the job.
        args: namespace with ``arch``, ``mode`` ("normal", "quantized" or
            another QAT mode), ``checkpoint``, ``data`` and ``workers`` (plus
            whatever ``infer`` reads).
    """
    # pylint: disable=too-many-statements
    rank = dist.get_rank()
    if world_size > 1:
        # NOTE(review): no init_process_group call here -- presumably the
        # launcher has already set up the group; this only reports it.
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))

    net = models.__dict__[args.arch]()

    if args.mode != "normal":
        quantize_qat(net, qconfig=Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        weights = mge.load(args.checkpoint)
        # Accept both a bare state dict and a {"state_dict": ...} wrapper.
        if "state_dict" in weights:
            weights = weights["state_dict"]
        net.load_state_dict(weights, strict=False)

    if args.mode == "quantized":
        quantize(net)

    def world_mean(value):
        return dist.functional.all_reduce_sum(value) / dist.get_world_size()

    def valid_func(image, label):
        net.eval()
        logits = net(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        top1, top5 = F.topk_accuracy(logits, label, (1, 5))
        if dist.is_distributed():
            loss, top1, top5 = world_mean(loss), world_mean(top1), world_mean(top5)
        return loss, top1, top5

    logger.info("preparing dataset..")
    imagenet_val = data.dataset.ImageNet(args.data, train=False)
    ordered = data.SequentialSampler(imagenet_val,
                                     batch_size=100,
                                     drop_last=False)
    eval_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW")
    ])
    val_queue = data.DataLoader(
        imagenet_val,
        sampler=ordered,
        transform=eval_transform,
        num_workers=args.workers,
    )

    _, top1_acc, top5_acc = infer(valid_func, val_queue, args)
    if rank == 0:
        logger.info("TEST %f, %f", top1_acc, top5_acc)
Ejemplo n.º 5
0
def build_dataset(args):
    """Build ImageNet train (infinite, shuffled) and validation loaders.

    "resnet18" and "resnet34" get crop + flip augmentation. Every other
    architecture also gets ColorJitter. All images are normalized with the
    BGR channel statistics below and converted to CHW.

    Returns:
        tuple: ``(train_dataloader, valid_dataloader)``.
    """
    def bgr_normalize():
        # Build fresh mean/std lists on every call, as the original did.
        return T.Normalize(mean=[103.530, 116.280, 123.675],
                           std=[57.375, 57.120, 58.395])  # BGR

    imagenet_train = data.dataset.ImageNet(args.data, train=True)
    endless = data.Infinite(
        data.RandomSampler(imagenet_train,
                           batch_size=args.batch_size,
                           drop_last=True))

    train_steps = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
    if args.arch not in ("resnet18", "resnet34"):
        # Facebook-style augmentation for the larger models
        train_steps.append(
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4))
    train_steps.append(bgr_normalize())
    train_steps.append(T.ToMode("CHW"))

    train_loader = data.DataLoader(
        imagenet_train,
        sampler=endless,
        transform=T.Compose(train_steps),
        num_workers=args.workers,
    )

    imagenet_val = data.dataset.ImageNet(args.data, train=False)
    ordered = data.SequentialSampler(imagenet_val,
                                     batch_size=100,
                                     drop_last=False)
    eval_steps = [
        T.Resize(256),
        T.CenterCrop(224),
        bgr_normalize(),
        T.ToMode("CHW"),
    ]
    val_loader = data.DataLoader(
        imagenet_val,
        sampler=ordered,
        transform=T.Compose(eval_steps),
        num_workers=args.workers,
    )
    return train_loader, val_loader
Ejemplo n.º 6
0
def build_dataset(args):
    """Build SIDD train (infinite, shuffled) and validation (sequential) loaders.

    The global training batch ``args.batch_size`` and the fixed global
    validation batch of 4 are each divided (floor) by ``args.ngpus`` to get
    the per-device batch size.

    Args:
        args: namespace with ``data``, ``batch_size``, ``ngpus``,
            ``steps_per_epoch`` and ``workers``.

    Returns:
        tuple: ``(train_dataloader, valid_dataloader)``.

    Raises:
        ValueError: if either per-device batch size would be smaller than 1.
    """
    valid_batch_size = 4  # global validation batch, split across devices
    train_per_device = args.batch_size // args.ngpus
    valid_per_device = valid_batch_size // args.ngpus
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if train_per_device < 1 or valid_per_device < 1:
        raise ValueError(
            "per-device batch sizes must be >= 1, got train={} and valid={} "
            "(batch_size={}, ngpus={})".format(
                train_per_device, valid_per_device,
                args.batch_size, args.ngpus))

    # Length is batch_size * steps_per_epoch; presumably one pass then equals
    # one epoch of global batches -- confirm against SIDDData.
    train_dataset = SIDDData(args.data, length=args.batch_size * args.steps_per_epoch)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset, batch_size=train_per_device, drop_last=True)
    )
    train_dataloader = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=args.workers,
    )
    valid_dataset = SIDDValData(args.data)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=valid_per_device, drop_last=False
    )
    valid_dataloader = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        num_workers=args.workers,
    )
    return train_dataloader, valid_dataloader
Ejemplo n.º 7
0
def build_dataloader(dataset_dir):
    """Create the sequential, normalized PascalVOC validation dataloader.

    Returns:
        tuple: ``(dataloader, number_of_val_samples)``.
    """
    voc_val = dataset.PascalVOC(dataset_dir, "val", order=["image", "mask"])
    ordered = data.SequentialSampler(voc_val, cfg.VAL_BATCHES)
    normalize = T.Normalize(mean=cfg.IMG_MEAN,
                            std=cfg.IMG_STD,
                            order=["image", "mask"])
    loader = data.DataLoader(
        voc_val,
        sampler=ordered,
        transform=normalize,
        num_workers=cfg.DATA_WORKERS,
    )
    return loader, len(voc_val)
Ejemplo n.º 8
0
    def prepare_dataset(name):
        """Load an evaluation dataset and wrap it in an in-order dataloader.

        Args:
            name (str): evaluation set to load; expected to be one of
                {facescrub, megaface}.

        Returns:
            tuple: ``(dataset, queue)``, where ``queue`` is a
            ``data.DataLoader`` that walks ``dataset`` sequentially in
            batches of ``configs["batch_size"]``, normalizing and converting
            each image to CHW.
        """
        eval_set = get_eval_dataset(name, dataset_dir=configs["dataset_dir"])
        ordered = data.SequentialSampler(eval_set, batch_size=configs["batch_size"])
        to_chw = T.Compose([T.Normalize(mean=127.5, std=128), T.ToMode("CHW")])
        loader = data.DataLoader(eval_set, sampler=ordered, transform=to_chw)
        return eval_set, loader
Ejemplo n.º 9
0
def build_dataset(args):
    """Build only the ImageNet validation loader.

    Returns:
        tuple: ``(None, valid_dataloader)``. The train slot is ``None`` so the
        return shape matches the train/valid variant.
    """
    imagenet_val = data.dataset.ImageNet(args.data, train=False)
    ordered = data.SequentialSampler(imagenet_val,
                                     batch_size=100,
                                     drop_last=False)
    eval_steps = [
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=[103.530, 116.280, 123.675],
                    std=[57.375, 57.120, 58.395]),  # BGR
        T.ToMode("CHW"),
    ]
    val_loader = data.DataLoader(
        imagenet_val,
        sampler=ordered,
        transform=T.Compose(eval_steps),
        num_workers=args.workers,
    )
    return None, val_loader
Ejemplo n.º 10
0
def build_dataloader(dataset_dir, cfg):
    """Create the validation dataloader for VOC2012 or Cityscapes.

    Samples carry ("image", "mask", "info"). Only image and mask are passed
    through ``T.Normalize``.

    Returns:
        tuple: ``(dataloader, number_of_val_samples)``.

    Raises:
        ValueError: if ``cfg.DATASET`` is neither "VOC2012" nor "Cityscapes".
    """
    dataset_name = cfg.DATASET
    fields = ["image", "mask", "info"]
    if dataset_name == "VOC2012":
        val_split = EvalPascalVOC(dataset_dir, "val", order=fields)
    elif dataset_name == "Cityscapes":
        val_split = dataset.Cityscapes(dataset_dir, "val",
                                       mode='gtFine', order=fields)
    else:
        raise ValueError("Unsupported dataset {}".format(dataset_name))

    normalize = T.Normalize(mean=cfg.IMG_MEAN,
                            std=cfg.IMG_STD,
                            order=["image", "mask"])
    loader = data.DataLoader(
        val_split,
        sampler=data.SequentialSampler(val_split, cfg.VAL_BATCHES),
        transform=normalize,
        num_workers=cfg.DATA_WORKERS,
    )
    return loader, len(val_split)
Ejemplo n.º 11
0
def build_dataloader(batch_size, dataset_dir, cfg):
    """Create the shuffled, augmented training loader for VOC2012 or Cityscapes.

    VOC2012 reads split ``cfg.DATA_TYPE``; Cityscapes always reads "train"
    with fine annotations. Crops are ``(cfg.IMG_HEIGHT, cfg.IMG_WIDTH)``.

    Returns:
        tuple: ``(dataloader, number_of_training_samples)``.

    Raises:
        ValueError: if ``cfg.DATASET`` is neither "VOC2012" nor "Cityscapes".
    """
    dataset_name = cfg.DATASET
    fields = ["image", "mask"]
    if dataset_name == "VOC2012":
        train_split = dataset.PascalVOC(dataset_dir, cfg.DATA_TYPE, order=fields)
    elif dataset_name == "Cityscapes":
        train_split = dataset.Cityscapes(dataset_dir, "train",
                                         mode='gtFine', order=fields)
    else:
        raise ValueError("Unsupported dataset {}".format(dataset_name))

    shuffler = data.RandomSampler(train_split, batch_size, drop_last=True)
    augmentations = [
        T.RandomHorizontalFlip(0.5),
        T.RandomResize(scale_range=(0.5, 2)),
        T.RandomCrop(
            output_size=(cfg.IMG_HEIGHT, cfg.IMG_WIDTH),
            padding_value=[0, 0, 0],
            padding_maskvalue=255,
        ),
        T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD),
        T.ToMode(),
    ]
    loader = data.DataLoader(
        train_split,
        sampler=shuffler,
        transform=T.Compose(transforms=augmentations, order=["image", "mask"]),
        num_workers=0,
    )
    return loader, len(train_split)
Ejemplo n.º 12
0
def worker(rank, world_size, args):
    """Train an ImageNet classifier in one process of a (possibly) distributed job.

    The learning rate decays linearly from ``args.learning_rate`` to 0 over
    ``args.steps`` steps and is held at 0 for 1250 further steps.

    Rank 0 writes a checkpoint to ``args.save/args.arch`` every 10000 steps
    and once more at the end. Validation runs on every rank every 10000 steps
    (skipping step 0) and once more at the end.

    Args:
        rank (int): index of this process; also used as its device id.
        world_size (int): number of processes; when > 1 the distributed
            process group is initialized on localhost:23456.
        args: parsed options -- ``save``, ``arch``, ``data``, ``batch_size``,
            ``workers``, ``learning_rate``, ``momentum``, ``weight_decay``,
            ``steps`` and ``report_freq`` (plus whatever ``infer`` reads).
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch)
    if rank == 0:
        # Only rank 0 writes checkpoints; make sure their directory exists.
        os.makedirs(save_dir, exist_ok=True)

    model = getattr(M, args.arch)()

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.RandomSampler(train_dataset,
                                       batch_size=args.batch_size,
                                       drop_last=True)
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.Normalize(mean=[103.530, 116.280, 123.675],
                        std=[57.375, 57.120, 58.395]),  # BGR
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = infinite_iter(train_queue)

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=[103.530, 116.280, 123.675],
                        std=[57.375, 57.120, 58.395]),  # BGR
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    # The loop runs 1250 steps past args.steps with the learning rate clamped
    # to 0. NOTE(review): presumably this lets BatchNorm running statistics
    # settle through train-mode forward passes -- confirm.
    for step in range(0, args.steps + 1250 + 1):
        # Linear learning rate decay
        decay = 1 - float(step) / args.steps if step < args.steps else 0.0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN %06d %f %s %s %s %s",
                step,
                args.learning_rate * decay,
                objs,
                top1,
                top5,
                total_time,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                model.state_dict(),
                os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)),
            )
        if step % 10000 == 0 and step != 0:
            # Every rank must call infer: valid_func all-reduces across ranks.
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 saves, as inside the loop; otherwise every rank would write
    # the same file concurrently.
    if rank == 0:
        mge.save(model.state_dict(),
                 os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)))
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Ejemplo n.º 13
0
def _step_from_checkpoint_path(path):
    """Return the step number encoded in a ``checkpoint-NNNNNN.pkl`` file name.

    Only the base name is parsed, so hyphens or dots in parent directories
    (e.g. ``/runs/my-exp/``) do not break the lookup.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit("-", 1)[-1])


def worker(rank, world_size, args):
    """Train an ImageNet classifier, optionally resuming from a step checkpoint.

    The learning rate decays linearly from ``args.learning_rate`` to 0 at
    ``args.steps``.

    Checkpoints are written by rank 0 every 10000 steps and at the end.
    Validation runs every 50000 steps and at the end, with top-1/top-5 logged
    as error rates.

    Args:
        rank (int): index of this process; also used as its device id.
        world_size (int): number of processes; when > 1 the distributed
            process group is initialized on localhost:23456.
        args: parsed options -- ``save``, ``arch``, ``model`` (a
            ``checkpoint-NNNNNN.pkl`` path to resume from, or falsy),
            ``data``, ``batch_size``, ``workers``, ``learning_rate``,
            ``momentum``, ``weight_decay``, ``steps`` and ``report_freq``
            (plus whatever ``infer`` reads).
    """
    # pylint: disable=too-many-statements
    save_dir = os.path.join(args.save, args.arch)
    # The log file lives inside save_dir, so make sure the directory exists.
    os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    model = getattr(M, args.arch)()
    step_start = 0
    if args.model:
        logger.info("load weights from %s", args.model)
        model.load_state_dict(mge.load(args.model))
        step_start = _step_from_checkpoint_path(args.model)

    optimizer = optim.SGD(
        get_parameters(model),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=args.batch_size,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    step = step_start  # keeps `step` defined if the loop below never runs
    for step in range(step_start, args.steps + 1):
        # Linear learning rate decay
        decay = 1 - float(step) / args.steps if step < args.steps else 0
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.learning_rate * decay

        image, label = next(train_queue)
        time_data = time.time() - t
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        time_iter = time.time() - t
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            # NOTE(review): averages are recovered by parsing the meters'
            # string form; a numeric accessor would be sturdier if one exists.
            logger.info(
                "TRAIN Iter %06d: lr = %f,\tloss = %f,\twc_loss = 1,\tTop-1 err = %f,\tTop-5 err = %f,\tdata_time = %f,\ttrain_time = %f,\tremain_hours=%f",
                step,
                args.learning_rate * decay,
                float(objs.__str__().split()[1]),
                1 - float(top1.__str__().split()[1]) / 100,
                1 - float(top5.__str__().split()[1]) / 100,
                time_data,
                time_iter - time_data,
                time_iter * (args.steps - step) / 3600,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0 and step != 0:
            logger.info("SAVING %06d", step)
            mge.save(
                model.state_dict(),
                os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)),
            )
        if step % 50000 == 0 and step != 0:
            valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info(
                "TEST Iter %06d: loss = %f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, valid_loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)

    # Only rank 0 saves, as inside the loop; otherwise every rank would write
    # the same file concurrently.
    if rank == 0:
        mge.save(model.state_dict(),
                 os.path.join(save_dir, "checkpoint-{:06d}.pkl".format(step)))
    valid_loss, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST Iter %06d: loss=%f,\tTop-1 err = %f,\tTop-5 err = %f",
                step, valid_loss, 1 - valid_acc / 100, 1 - valid_acc5 / 100)
Ejemplo n.º 14
0
def worker(master_ip, port, rank, world_size, args):
    """Train a COCO keypoint model, checkpointing periodically and at the end.

    Args:
        master_ip (str): address of the distributed master.
        port (int): port of the distributed master.
        rank (int): index of this process; also used as its device id.
        world_size (int): number of processes; when > 1 the process group is
            initialized and gradients are all-reduced with "SUM".
        args: namespace with ``arch``, ``save``, ``resume`` (checkpoint dict
            path holding "state_dict" and "epoch", or None), ``workers`` and
            ``multi_scale_supervision``.

    Rank 0 writes ``epoch_{epoch}.pkl`` every ``cfg.save_freq`` epochs and
    always after the final epoch.
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip=master_ip,
            port=port,
            world_size=world_size,
            rank=rank,
            device=rank,
        )

    model_name = "{}_{}x{}".format(args.arch, cfg.input_shape[0], cfg.input_shape[1])
    save_dir = os.path.join(args.save, model_name)
    if rank == 0:
        # Only rank 0 writes checkpoints; make sure their directory exists.
        os.makedirs(save_dir, exist_ok=True)

    model = getattr(kpm, args.arch)()
    model.train()
    start_epoch = 0
    if args.resume is not None:
        file = mge.load(args.resume)
        model.load_state_dict(file["state_dict"])
        # The stored epoch is the next epoch to run (saved as epoch + 1).
        start_epoch = file["epoch"]

    optimizer = optim.Adam(
        model.parameters(), lr=cfg.initial_lr, weight_decay=cfg.weight_decay
    )

    gm = GradManager()
    if dist.get_world_size() > 1:
        gm.attach(
            model.parameters(), callbacks=[dist.make_allreduce_cb("SUM", dist.WORLD)],
        )
    else:
        gm.attach(model.parameters())

    if dist.get_world_size() > 1:
        dist.bcast_list_(model.parameters(), dist.WORLD)  # sync parameters

    # Build train datasets
    logger.info("preparing dataset..")
    ann_file = os.path.join(
        cfg.data_root, "annotations", "person_keypoints_train2017.json"
    )
    train_dataset = COCOJoints(
        cfg.data_root,
        ann_file,
        image_set="train2017",
        order=("image", "keypoints", "boxes", "info"),
    )
    logger.info("Num of Samples: {}".format(len(train_dataset)))
    train_sampler = data.RandomSampler(
        train_dataset, batch_size=cfg.batch_size, drop_last=True
    )

    transforms = [
        T.Normalize(mean=cfg.img_mean, std=cfg.img_std),
        RandomHorizontalFlip(0.5, keypoint_flip_order=cfg.keypoint_flip_order)
    ]

    if cfg.half_body_transform:
        transforms.append(
            HalfBodyTransform(
                cfg.upper_body_ids, cfg.lower_body_ids, cfg.prob_half_body
            )
        )
    if cfg.extend_boxes:
        transforms.append(
            ExtendBoxes(cfg.x_ext, cfg.y_ext, cfg.input_shape[1] / cfg.input_shape[0])
        )

    transforms += [
        RandomBoxAffine(
            degrees=cfg.rotate_range,
            scale=cfg.scale_range,
            output_shape=cfg.input_shape,
            rotate_prob=cfg.rotation_prob,
            scale_prob=cfg.scale_prob,
        )
    ]
    transforms += [T.ToMode()]

    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        num_workers=args.workers,
        transform=T.Compose(transforms=transforms, order=train_dataset.order,),
        collator=HeatmapCollator(
            cfg.input_shape,
            cfg.output_shape,
            cfg.keypoint_num,
            cfg.heat_thr,
            cfg.heat_kernels if args.multi_scale_supervision else cfg.heat_kernels[-1:],
            cfg.heat_range,
        ),
    )

    # Start training
    last_epoch = cfg.epochs - 1
    for epoch in range(start_epoch, cfg.epochs):
        loss = train(model, train_queue, optimizer, gm, epoch=epoch)
        logger.info("Epoch %d Train %.6f ", epoch, loss)

        # Always keep the final weights, even when the last epoch is not a
        # multiple of save_freq.
        should_save = epoch % cfg.save_freq == 0 or epoch == last_epoch
        if rank == 0 and should_save:  # save checkpoint
            mge.save(
                {"epoch": epoch + 1, "state_dict": model.state_dict()},
                os.path.join(save_dir, "epoch_{}.pkl".format(epoch)),
            )
Ejemplo n.º 15
0
def worker(rank, world_size, args):
    """Train a COCO keypoint model with Adam, checkpointing after every epoch.

    Args:
        rank (int): index of this process; also used as its device id.
        world_size (int): number of processes; when > 1 the distributed
            process group is initialized on localhost:23456.
        args: namespace with ``arch``, ``save``, ``pretrained``, ``c``
            (checkpoint dict path holding "state_dict" and "epoch", or None),
            ``lr``, ``data_root``, ``ann_file``, ``batch_size``, ``workers``,
            ``epochs`` and ``multi_scale_supervision`` (plus whatever
            ``train`` reads).
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    height, width = cfg.input_shape[0], cfg.input_shape[1]
    save_dir = os.path.join(args.save, "{}_{}x{}".format(args.arch, height, width))

    net = getattr(M, args.arch)(pretrained=args.pretrained)
    net.train()
    first_epoch = 0
    if args.c is not None:
        resume_state = mge.load(args.c)
        net.load_state_dict(resume_state["state_dict"])
        first_epoch = resume_state["epoch"]

    optimizer = optim.Adam(
        net.parameters(requires_grad=True),
        lr=args.lr,
        weight_decay=cfg.weight_decay,
    )

    logger.info("preparing dataset..")
    coco_train = COCOJoints(
        args.data_root,
        args.ann_file,
        image_set="train",
        order=("image", "keypoints", "boxes", "info"),
    )
    shuffler = data.RandomSampler(coco_train,
                                  batch_size=args.batch_size,
                                  drop_last=True)

    # Pipeline order: normalize -> [half body] -> [extend boxes] -> flip ->
    # affine -> ToMode.
    pipeline = [T.Normalize(mean=cfg.IMG_MEAN, std=cfg.IMG_STD)]
    if cfg.half_body_transform:
        pipeline.append(HalfBodyTransform(cfg.upper_body_ids,
                                          cfg.lower_body_ids,
                                          cfg.prob_half_body))
    if cfg.extend_boxes:
        pipeline.append(ExtendBoxes(cfg.x_ext, cfg.y_ext, width / height))
    pipeline.extend([
        RandomHorizontalFlip(0.5, keypoint_flip_order=cfg.keypoint_flip_order),
        RandomBoxAffine(
            degrees=cfg.rotate_range,
            scale=cfg.scale_range,
            output_shape=cfg.input_shape,
            rotate_prob=cfg.rotation_prob,
            scale_prob=cfg.scale_prob,
        ),
        T.ToMode(),
    ])

    if args.multi_scale_supervision:
        kernels = cfg.heat_kernel
    else:
        kernels = cfg.heat_kernel[-1:]

    train_queue = data.DataLoader(
        coco_train,
        sampler=shuffler,
        num_workers=args.workers,
        transform=T.Compose(transforms=pipeline, order=coco_train.order),
        collator=HeatmapCollator(
            cfg.input_shape,
            cfg.output_shape,
            cfg.keypoint_num,
            cfg.heat_thre,
            kernels,
            cfg.heat_range,
        ),
    )

    for epoch in range(first_epoch, args.epochs):
        epoch_loss = train(net, train_queue, optimizer, args, epoch=epoch)
        logger.info("Epoch %d Train %.6f ", epoch, epoch_loss)

        if rank != 0:
            continue
        # "epoch" stores the next epoch to run so resuming continues from there.
        snapshot = {"epoch": epoch + 1, "state_dict": net.state_dict()}
        mge.save(snapshot, os.path.join(save_dir, "epoch_{}.pkl".format(epoch)))
Ejemplo n.º 16
0
def worker(rank, world_size, args):
    """Fine-tune an ImageNet classifier on one process of a (possibly) distributed job.

    The model named by ``args.arch`` is optionally prepared for
    quantization-aware training (any ``args.mode`` other than "normal"),
    initialised from ``args.checkpoint`` when given, and trained with SGD for
    ``steps_per_epoch * cfg.EPOCHS`` steps. Logs and checkpoints are written
    to ``<args.save>/<args.arch>.<args.mode>/``.

    Args:
        rank: index of this process; also used as its device id.
        world_size: number of processes; values > 1 enable distributed
            training and scale the learning rate linearly.
        args: parsed command-line namespace. Attributes read here: save,
            arch, mode, checkpoint, data, workers, report_freq.

    Raises:
        ValueError: if ``args.mode == "quantized"`` (inference-only), or if
            ``cfg.SCHEDULER`` is neither "Linear" nor "Multistep".
    """
    # pylint: disable=too-many-statements

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_finetune_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        Q.quantize_qat(model, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)

    if args.mode == "quantized":
        # Full quantization is an inference-only conversion; fine-tuning
        # supports only "normal" and QAT modes.
        raise ValueError("mode = quantized only used during inference")

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    # Infinite sampler: the step loop below, not the loader, decides when
    # training ends.
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=cfg.BATCH_SIZE,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the LR for ``step`` according to ``cfg.SCHEDULER``."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Update the learning rate per the configured scheduler
        # (linear decay or multistep).
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = image.astype("float32")
        label = label.astype("int32")

        n = image.shape[0]

        optimizer.zero_grad()
        loss, acc1, acc5 = train_func(image, label)
        optimizer.step()

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info("TRAIN e%d %06d %f %s %s %s %s", epoch, step,
                        learning_rate, objs, top1, top5, total_time)
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step % 10000 == 0 and rank == 0:
            logger.info("SAVING %06d", step)
            mge.save(
                {
                    "step": step,
                    "state_dict": model.state_dict()
                },
                os.path.join(save_dir, "checkpoint.pkl"),
            )
        if step % 10000 == 0 and step != 0:
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes the final checkpoint (like the periodic save above)
    # so that processes do not race on the same file.
    if rank == 0:
        mge.save({
            "step": step,
            "state_dict": model.state_dict()
        }, os.path.join(save_dir, "checkpoint-final.pkl"))
    # Every rank must evaluate: valid_func all-reduces across processes.
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Ejemplo n.º 17
0
def worker(master_ip, port, world_size, rank, configs):
    """Train a face-recognition model on one process of a (possibly) distributed job.

    Args:
        master_ip: address of the distributed master; used only when
            ``world_size > 1``.
        port: port of the distributed master.
        world_size: number of participating processes.
        rank: index of this process; also used as its device id.
        configs: settings dict. Keys read directly here include "base_dir",
            "dataset", "dataset_dir", "batch_size", "learning_rate",
            "momentum", "weight_decay", "log_interval" and "num_epoch"; the
            whole dict is also passed to ``FaceRecognitionModel`` and
            ``adjust_learning_rate``. ``configs["learning_rate"]`` is
            overwritten in place with the world-size-scaled value.
    """
    if world_size > 1:
        dist.init_process_group(
            master_ip=master_ip,
            port=port,
            world_size=world_size,
            rank=rank,
            device=rank,
        )
        logger.info("init process group for gpu{} done".format(rank))

    # set up logger
    os.makedirs(configs["base_dir"], exist_ok=True)
    worklog_path = os.path.join(configs["base_dir"], "worklog.txt")
    mge.set_log_file(worklog_path)

    # prepare model-related components
    model = FaceRecognitionModel(configs)

    # prepare data-related components
    preprocess = T.Compose([T.Normalize(mean=127.5, std=128), T.ToMode("CHW")])
    augment = T.Compose([T.RandomHorizontalFlip()])

    train_dataset = get_train_dataset(configs["dataset"],
                                      dataset_dir=configs["dataset_dir"])
    train_sampler = data.RandomSampler(train_dataset,
                                       batch_size=configs["batch_size"],
                                       drop_last=True)
    train_queue = data.DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  transform=T.Compose([augment, preprocess]))

    # prepare optimize-related components
    # Linear LR scaling with the number of processes.
    configs["learning_rate"] = configs["learning_rate"] * dist.get_world_size()
    if dist.get_world_size() > 1:
        # Broadcast parameters so all ranks start from identical weights, and
        # average gradients across ranks during backward.
        dist.bcast_list_(model.parameters())
        gm = ad.GradManager().attach(
            model.parameters(), callbacks=[dist.make_allreduce_cb("mean")])
    else:
        gm = ad.GradManager().attach(model.parameters())
    opt = optim.SGD(
        model.parameters(),
        lr=configs["learning_rate"],
        momentum=configs["momentum"],
        weight_decay=configs["weight_decay"],
    )

    # try to load checkpoint
    model, start_epoch = try_load_latest_checkpoint(model, configs["base_dir"])

    # do training
    def train_one_epoch(epoch):
        """Run one full pass over ``train_queue``.

        ``epoch`` is passed explicitly (instead of being read from the
        enclosing loop through a closure) and is used only for logging.
        """
        def train_func(images, labels):
            opt.clear_grad()
            with gm:
                loss, accuracy, _ = model(images, labels)
                gm.backward(loss)
                if dist.is_distributed():
                    # all_reduce_mean
                    loss = dist.functional.all_reduce_sum(
                        loss) / dist.get_world_size()
                    accuracy = dist.functional.all_reduce_sum(
                        accuracy) / dist.get_world_size()
            opt.step()
            return loss, accuracy

        model.train()

        # Meters accumulate over the whole epoch (they are never reset).
        average_loss = AverageMeter("loss")
        average_accuracy = AverageMeter("accuracy")
        data_time = AverageMeter("data_time")
        train_time = AverageMeter("train_time")

        total_step = len(train_queue)
        data_iter = iter(train_queue)
        for step in range(total_step):
            # get next batch of data
            data_tic = time.time()
            images, labels = next(data_iter)
            data_toc = time.time()

            # forward pass & backward pass
            train_tic = time.time()
            images = mge.tensor(images, dtype="float32")
            labels = mge.tensor(labels, dtype="int32")
            loss, accuracy = train_func(images, labels)
            train_toc = time.time()

            # do the statistics and logging
            n = images.shape[0]
            average_loss.update(loss.item(), n)
            average_accuracy.update(accuracy.item() * 100, n)
            data_time.update(data_toc - data_tic)
            train_time.update(train_toc - train_tic)
            if step % configs["log_interval"] == 0 and dist.get_rank() == 0:
                logger.info(
                    "epoch: %d, step: %d, %s, %s, %s, %s",
                    epoch,
                    step,
                    average_loss,
                    average_accuracy,
                    data_time,
                    train_time,
                )

    for epoch in range(start_epoch, configs["num_epoch"]):
        adjust_learning_rate(opt, epoch, configs)
        train_one_epoch(epoch)

        if dist.get_rank() == 0:
            checkpoint_path = os.path.join(configs["base_dir"],
                                           f"epoch-{epoch+1}-checkpoint.pkl")
            mge.save(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict()
                },
                checkpoint_path,
            )
Ejemplo n.º 18
0
def worker(rank, world_size, args):
    """Evaluate an ImageNet classifier on the validation split and log accuracy.

    Depending on ``args.mode`` the model is left as-is ("normal"), prepared
    for quantization-aware training (any other mode), and additionally
    converted with ``Q.quantize`` when the mode is "quantized". Weights come
    from ``args.checkpoint`` when it is set.

    Args:
        rank: index of this process; also used as its device id.
        world_size: number of processes; values > 1 initialise a process group.
        args: parsed command-line namespace (arch, mode, checkpoint, data,
            workers are read here; the whole namespace goes to ``infer``).
    """
    # pylint: disable=too-many-statements

    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(master_ip="localhost", master_port=23456,
                                world_size=world_size, rank=rank, dev=rank)

    net = models.__dict__[args.arch]()

    mode = args.mode
    if mode != "normal":
        Q.quantize_qat(net, Q.ema_fakequant_qconfig)

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        loaded = mge.load(args.checkpoint)
        # Accept either a {"state_dict": ...} wrapper or a bare state dict.
        if "state_dict" in loaded:
            weights = loaded["state_dict"]
        else:
            weights = loaded
        net.load_state_dict(weights, strict=False)

    if mode == "quantized":
        Q.quantize(net)

    def _mean_over_ranks(value, key):
        """Sum ``value`` over all ranks under ``key`` and divide by world size."""
        return dist.all_reduce_sum(value, key) / dist.get_world_size()

    @jit.trace(symbolic=True)
    def evaluate_batch(image, label):
        net.eval()
        logits = net(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        top1, top5 = F.accuracy(logits, label, (1, 5))
        if not dist.is_distributed():
            return loss, top1, top5
        return (
            _mean_over_ranks(loss, "valid_loss"),
            _mean_over_ranks(top1, "valid_acc1"),
            _mean_over_ranks(top5, "valid_acc5"),
        )

    logger.info("preparing dataset..")
    val_set = data.dataset.ImageNet(args.data, train=False)
    val_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.Normalize(mean=128),
        T.ToMode("CHW"),
    ])
    val_loader = data.DataLoader(
        val_set,
        sampler=data.SequentialSampler(val_set, batch_size=100,
                                       drop_last=False),
        transform=val_transform,
        num_workers=args.workers,
    )

    _, top1_acc, top5_acc = infer(evaluate_batch, val_loader, args)
    logger.info("TEST %f, %f", top1_acc, top5_acc)
Ejemplo n.º 19
0
def worker(rank, world_size, args):
    """Train an ImageNet classifier with SGD and a MultiStepLR schedule.

    Validation runs after every epoch. Rank 0 writes ``checkpoint.pkl`` each
    epoch and copies it to ``model_best.pkl`` whenever validation top-1
    accuracy improves. Both files go under ``<args.save>/<args.arch>/``.

    Args:
        rank: index of this process; also used as its device id.
        world_size: number of processes; values > 1 initialise a process group.
        args: parsed command-line namespace. Attributes read here: save,
            arch, learning_rate, momentum, weight_decay, data, batch_size,
            workers, epochs.
    """
    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch)
    # Checkpoints are written into save_dir below, so make sure it exists.
    # exist_ok keeps this safe when several ranks create it concurrently.
    os.makedirs(save_dir, exist_ok=True)

    model = getattr(M, args.arch)()

    optimizer = optim.SGD(
        model.parameters(requires_grad=True),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    scheduler = optim.MultiStepLR(optimizer, [30, 60, 80])

    # Define train and valid graph
    @jit.trace(symbolic=True)
    def train_func(image, label):
        model.train()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        optimizer.backward(loss)  # compute gradients
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "train_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "train_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "train_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    @jit.trace(symbolic=True)
    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        if dist.is_distributed():  # all_reduce_mean
            loss = dist.all_reduce_sum(loss,
                                       "valid_loss") / dist.get_world_size()
            acc1 = dist.all_reduce_sum(acc1,
                                       "valid_acc1") / dist.get_world_size()
            acc5 = dist.all_reduce_sum(acc5,
                                       "valid_acc5") / dist.get_world_size()
        return loss, acc1, acc5

    # Per-channel normalization statistics in BGR order, shared by the
    # train and valid pipelines.
    bgr_mean = [103.530, 116.280, 123.675]
    bgr_std = [57.375, 57.120, 58.395]

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    train_sampler = data.RandomSampler(train_dataset,
                                       batch_size=args.batch_size,
                                       drop_last=True)
    # Baseline augmentation for small models; large models additionally get
    # the Facebook-style color jitter + lighting noise.
    train_transforms = [T.RandomResizedCrop(224), T.RandomHorizontalFlip()]
    if args.arch not in ("resnet18", "resnet34"):
        train_transforms += [
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.Lighting(0.1),
        ]
    train_transforms += [
        T.Normalize(mean=bgr_mean, std=bgr_std),
        T.ToMode("CHW"),
    ]
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose(train_transforms),
        num_workers=args.workers,
    )
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=bgr_mean, std=bgr_std),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )

    # Start training
    top1_acc = 0
    for epoch in range(0, args.epochs):
        logger.info("Epoch %d LR %.3e", epoch, scheduler.get_lr()[0])
        _, train_acc, train_acc5 = train(train_func,
                                         train_queue,
                                         optimizer,
                                         args,
                                         epoch=epoch)
        logger.info("Epoch %d Train %.3f / %.3f", epoch, train_acc, train_acc5)
        _, valid_acc, valid_acc5 = infer(valid_func,
                                         valid_queue,
                                         args,
                                         epoch=epoch)
        logger.info("Epoch %d Valid %.3f / %.3f", epoch, valid_acc, valid_acc5)
        scheduler.step()
        if rank == 0:  # save checkpoint
            mge.save(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "accuracy": valid_acc,
                },
                os.path.join(save_dir, "checkpoint.pkl"),
            )
            if valid_acc > top1_acc:
                top1_acc = valid_acc
                shutil.copy(
                    os.path.join(save_dir, "checkpoint.pkl"),
                    os.path.join(save_dir, "model_best.pkl"),
                )
Ejemplo n.º 20
0
def worker(world_size, args):
    """Train an ImageNet classifier (normal or QAT mode) with GradManager.

    This function does not initialise the process group itself; it reads
    the rank with ``dist.get_rank()``. Logs and checkpoints are written to
    ``<args.save>/<args.arch>.<args.mode>/``.

    Args:
        world_size: number of processes; values > 1 enable parameter
            broadcast, gradient all-reduce and linear LR scaling.
        args: parsed command-line namespace. Attributes read here: save,
            arch, mode, data, workers, report_freq.

    Raises:
        ValueError: if ``cfg.SCHEDULER`` is neither "Linear" nor "Multistep".
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        logger.info("init distributed process group {} / {}".format(
            rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()
    cfg = config.get_config(args.arch)

    cfg.LEARNING_RATE *= world_size  # scale learning rate in distributed training
    total_batch_size = cfg.BATCH_SIZE * world_size
    steps_per_epoch = 1280000 // total_batch_size
    total_steps = steps_per_epoch * cfg.EPOCHS

    if args.mode != "normal":
        quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

    if world_size > 1:
        # Sync parameters
        dist.bcast_list_(model.parameters(), dist.WORLD)

    # Autodiff gradient manager
    gm = autodiff.GradManager().attach(
        model.parameters(),
        callbacks=dist.make_allreduce_cb("MEAN") if world_size > 1 else None,
    )

    optimizer = optim.SGD(
        get_parameters(model, cfg),
        lr=cfg.LEARNING_RATE,
        momentum=cfg.MOMENTUM,
    )

    # Define train and valid graph
    def train_func(image, label):
        with gm:
            model.train()
            logits = model(image)
            loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
            acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
            gm.backward(loss)
            optimizer.step().clear_grad()
        return loss, acc1, acc5

    def valid_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return loss, acc1, acc5

    # Build train and valid datasets
    logger.info("preparing dataset..")
    train_dataset = data.dataset.ImageNet(args.data, train=True)
    # Infinite sampler: the step loop below decides when training ends.
    train_sampler = data.Infinite(
        data.RandomSampler(train_dataset,
                           batch_size=cfg.BATCH_SIZE,
                           drop_last=True))
    train_queue = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        transform=T.Compose([
            T.RandomResizedCrop(224),
            T.RandomHorizontalFlip(),
            cfg.COLOR_JITTOR,
            T.Normalize(mean=128),
            T.ToMode("CHW"),
        ]),
        num_workers=args.workers,
    )
    train_queue = iter(train_queue)
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(valid_dataset,
                                           batch_size=100,
                                           drop_last=False)
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.Normalize(mean=128),
            T.ToMode("CHW")
        ]),
        num_workers=args.workers,
    )

    def adjust_learning_rate(step, epoch):
        """Set and return the LR for ``step`` according to ``cfg.SCHEDULER``."""
        learning_rate = cfg.LEARNING_RATE
        if cfg.SCHEDULER == "Linear":
            learning_rate *= 1 - float(step) / total_steps
        elif cfg.SCHEDULER == "Multistep":
            learning_rate *= cfg.SCHEDULER_GAMMA**bisect.bisect_right(
                cfg.SCHEDULER_STEPS, epoch)
        else:
            raise ValueError(cfg.SCHEDULER)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    # Start training
    objs = AverageMeter("Loss")
    top1 = AverageMeter("Acc@1")
    top5 = AverageMeter("Acc@5")
    total_time = AverageMeter("Time")

    t = time.time()
    for step in range(0, total_steps):
        # Update the learning rate per the configured scheduler
        # (linear decay or multistep).
        epoch = step // steps_per_epoch
        learning_rate = adjust_learning_rate(step, epoch)

        image, label = next(train_queue)
        image = mge.tensor(image, dtype="float32")
        label = mge.tensor(label, dtype="int32")

        n = image.shape[0]

        loss, acc1, acc5 = train_func(image, label)

        top1.update(100 * acc1.numpy()[0], n)
        top5.update(100 * acc5.numpy()[0], n)
        objs.update(loss.numpy()[0], n)
        total_time.update(time.time() - t)
        t = time.time()
        if step % args.report_freq == 0 and rank == 0:
            logger.info(
                "TRAIN e%d %06d %f %s %s %s %s",
                epoch,
                step,
                learning_rate,
                objs,
                top1,
                top5,
                total_time,
            )
            objs.reset()
            top1.reset()
            top5.reset()
            total_time.reset()
        if step != 0 and step % 10000 == 0:
            # Periodic checkpoint (rank 0 only), then validation on all ranks.
            if rank == 0:
                logger.info("SAVING %06d", step)
                mge.save(
                    {
                        "step": step,
                        "state_dict": model.state_dict()
                    },
                    os.path.join(save_dir, "checkpoint.pkl"),
                )
            _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
            logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)

    # Only rank 0 writes the final checkpoint, matching the periodic save,
    # so processes do not race on the same file.
    if rank == 0:
        mge.save(
            {
                "step": step,
                "state_dict": model.state_dict()
            },
            os.path.join(save_dir, "checkpoint-final.pkl"),
        )
    _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args)
    logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5)
Ejemplo n.º 21
0
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
import megengine.data as data
import megengine.data.transform as T
import megengine.optimizer as optim

import megengine_mimicry as mmc
import megengine_mimicry.nets.dcgan.dcgan_cifar as dcgan

# CIFAR-10 training data, sampled forever in shuffled batches of 64;
# a trailing incomplete batch is dropped.
dataset = mmc.datasets.load_dataset(root=None, name='cifar10')
dataloader = data.DataLoader(
    dataset,
    sampler=data.Infinite(
        data.RandomSampler(dataset, batch_size=64, drop_last=True)
    ),
    # Scale pixels by 1/255 (assuming Normalize's default mean is 0 —
    # confirm) and convert HWC images to CHW layout.
    transform=T.Compose([T.Normalize(std=255), T.ToMode("CHW")]),
    num_workers=4,
)

# DCGAN generator/discriminator pair for CIFAR-sized images.
netG = dcgan.DCGANGeneratorCIFAR()
netD = dcgan.DCGANDiscriminatorCIFAR()

# Both optimizers: Adam, lr=2e-4, betas=(0.0, 0.9).
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Where the trainer writes its logs and checkpoints.
LOG_DIR = "./log/dcgan_example"

trainer = mmc.training.Trainer(netD=netD,
                               netG=netG,
                               optD=optD,
Ejemplo n.º 22
0
def worker(rank, world_size, args):
    """Post-training calibration and quantization of a pretrained ImageNet model.

    The steps are:

    1. Load ``args.checkpoint`` (required).
    2. Attach calibration observers with ``quantize_qat``, leaving
       ``model.fc`` unquantized.
    3. Run the validation set once to collect observer statistics.
    4. Convert the model with ``quantize`` and evaluate it.
    5. On rank 0 only, save the result to
       ``<args.save>/<args.arch>.<args.mode>/checkpoint-calibration.pkl``.

    Args:
        rank: index of this process; also used as its device id.
        world_size: number of processes; values > 1 initialise a process group.
        args: parsed command-line namespace. Attributes read here: save,
            arch, mode, checkpoint, data, workers.
    """
    # pylint: disable=too-many-statements

    if world_size > 1:
        # Initialize distributed process group
        logger.info("init distributed process group {} / {}".format(rank, world_size))
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

    save_dir = os.path.join(args.save, args.arch + "." + args.mode)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()

    # load calibration model
    assert args.checkpoint
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

    # Build valid datasets
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.Normalize(mean=128),
                T.ToMode("CHW"),
            ]
        ),
        num_workers=args.workers,
    )

    # calibration: the classifier head (model.fc) is kept unquantized
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    def _all_reduce_mean(value, key):
        """Average ``value`` over all processes; identity when not distributed."""
        if not dist.is_distributed():
            return value
        return dist.all_reduce_sum(value, key) / dist.get_world_size()

    # calculate scale: a forward pass with observers enabled collects the
    # statistics used by `quantize` below.
    @jit.trace(symbolic=True)
    def calculate_scale(image, label):
        model.eval()
        enable_observer(model)
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        loss = _all_reduce_mean(loss, "valid_loss")
        acc1 = _all_reduce_mean(acc1, "valid_acc1")
        acc5 = _all_reduce_mean(acc5, "valid_acc5")
        return loss, acc1, acc5

    infer(calculate_scale, valid_queue, args)

    # quantized
    model = quantize(model)

    # eval quantized model
    @jit.trace(symbolic=True)
    def eval_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
        acc1, acc5 = F.accuracy(logits, label, (1, 5))
        loss = _all_reduce_mean(loss, "valid_loss")
        acc1 = _all_reduce_mean(acc1, "valid_acc1")
        acc5 = _all_reduce_mean(acc5, "valid_acc5")
        return loss, acc1, acc5

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
    logger.info("TEST %f, %f", valid_acc, valid_acc5)

    # save quantized model; only rank 0 writes so processes do not race on
    # the same file.
    out_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
    if rank == 0:
        mge.save({"step": -1, "state_dict": model.state_dict()}, out_path)
        logger.info("save in {}".format(out_path))
Ejemplo n.º 23
0
def worker(world_size, args):
    """Post-training calibration and quantization of a pretrained ImageNet model.

    The steps are:

    1. Load ``args.checkpoint`` (required).
    2. Attach calibration observers with ``quantize_qat``, leaving
       ``model.fc`` unquantized.
    3. Run the validation set once to collect observer statistics.
    4. Convert the model with ``quantize`` and evaluate it.
    5. On rank 0 only, save the result to
       ``<args.save>/<args.arch>.calibration/checkpoint-calibration.pkl``.

    Args:
        world_size: number of processes (used here only for logging).
        args: parsed command-line namespace. Attributes read here: save,
            arch, checkpoint, data, workers.
    """
    # pylint: disable=too-many-statements

    rank = dist.get_rank()
    if world_size > 1:
        # This only logs; the process group is not initialised in this
        # function (presumably the caller/launcher does it — confirm).
        logger.info("init distributed process group {} / {}".format(rank, world_size))

    save_dir = os.path.join(args.save, args.arch + "." + "calibration")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    mge.set_log_file(os.path.join(save_dir, "log.txt"))

    model = models.__dict__[args.arch]()

    # load calibration model
    assert args.checkpoint
    logger.info("Load pretrained weights from %s", args.checkpoint)
    ckpt = mge.load(args.checkpoint)
    ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(ckpt, strict=False)

    # Build valid datasets
    valid_dataset = data.dataset.ImageNet(args.data, train=False)
    valid_sampler = data.SequentialSampler(
        valid_dataset, batch_size=100, drop_last=False
    )
    valid_queue = data.DataLoader(
        valid_dataset,
        sampler=valid_sampler,
        transform=T.Compose(
            [T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
        ),
        num_workers=args.workers,
    )

    # calibration: the classifier head (model.fc) is kept unquantized
    model.fc.disable_quantize()
    model = quantize_qat(model, qconfig=Q.calibration_qconfig)

    def _all_reduce_mean(value):
        """Average ``value`` over all processes; identity when not distributed."""
        if not dist.is_distributed():
            return value
        return dist.functional.all_reduce_sum(value) / dist.get_world_size()

    # calculate scale: a forward pass with observers enabled collects the
    # statistics used by `quantize` below.
    def calculate_scale(image, label):
        model.eval()
        enable_observer(model)
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return _all_reduce_mean(loss), _all_reduce_mean(acc1), _all_reduce_mean(acc5)

    infer(calculate_scale, valid_queue, args)

    # quantized
    model = quantize(model)

    # eval quantized model
    def eval_func(image, label):
        model.eval()
        logits = model(image)
        loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
        acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
        return _all_reduce_mean(loss), _all_reduce_mean(acc1), _all_reduce_mean(acc5)

    _, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
    logger.info("TEST %f, %f", valid_acc, valid_acc5)

    # save quantized model; only rank 0 writes so processes do not race on
    # the same file.
    out_path = os.path.join(save_dir, "checkpoint-calibration.pkl")
    if rank == 0:
        mge.save({"step": -1, "state_dict": model.state_dict()}, out_path)
        logger.info("save in {}".format(out_path))
Ejemplo n.º 24
0
def worker(rank, gpu_num, args):
    """Per-process training entry point for ``network.Network`` on CrowdHuman.

    Sets MegEngine sublinear-memory / CUDA env options, joins the distributed
    process group when ``gpu_num > 1`` (``rank`` also selects the device),
    builds the network, gradient manager and SGD optimizer, optionally loads
    pretrained backbone weights and/or a resume checkpoint, then trains from
    ``start_epoch`` up to ``cfg.max_epoch`` with step learning-rate decay.
    After every epoch, rank 0 writes ``epoch-{epoch_id + 1}.pkl`` into
    ``cfg.model_dir``.

    Args:
        rank: index of this process within the job; rank 0 starts the
            distributed server and writes checkpoints.
        gpu_num: total number of processes (world size).
        args: parsed CLI arguments; ``args.port`` and ``args.resume_weights``
            are read.

    Raises:
        FileNotFoundError: if ``args.resume_weights`` is set but the file
            does not exist.
    """
    # using sublinear memory optimisation
    os.environ["MGB_COMP_GRAPH_OPT"] = "enable_sublinear_memory_opt=1;seq_opt.enable_seq_comp_node_opt=0"
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = '10'
    os.environ['MGB_CUDA_RESERVE_MEMORY'] = '1'

    # establish the server if this is the master
    dist_port = args.port
    if rank == 0:
        dist.Server(port=dist_port)
    if gpu_num > 1:
        dist.init_process_group(
            master_ip="localhost",
            port=dist_port,
            world_size=gpu_num,
            rank=rank,
            device=rank,
        )
        logger.info("Init process group for gpu%d done", rank)

    model = network.Network()
    # model.parameters() may return a one-shot generator; materialise it so the
    # exact same parameter set is given to both the GradManager and the optimizer.
    params = list(model.parameters(requires_grad=True))
    model.train()

    # Autodiff gradient manager
    gm = autodiff.GradManager().attach(
        params,
        callbacks=allreduce_cb,
    )

    # Base LR is scaled by the global batch size (gpu_num * batch_per_gpu).
    base_lr = cfg.basic_lr * gpu_num * cfg.batch_per_gpu
    opt = optim.SGD(
        params,
        lr=base_lr,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
    )

    if cfg.pretrain_weight is not None:
        weights = mge.load(cfg.pretrain_weight)
        # Drop the classification head (if present); only model.resnet50 is loaded.
        weights.pop('fc.weight', None)
        weights.pop('fc.bias', None)
        model.resnet50.load_state_dict(weights)

    start_epoch = 0
    if args.resume_weights is not None:
        model_file = args.resume_weights
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not osp.exists(model_file):
            raise FileNotFoundError(model_file)
        logger.info("Loading %s to initialize FPN...", model_file)
        model_dict = mge.load(model_file)
        # Checkpoints below store the epoch that just finished; resume at the next one.
        start_epoch = model_dict['epoch'] + 1
        model.load_state_dict(model_dict['state_dict'], strict=False)

    logger.info("Prepare dataset")
    train_dataset = CrowdHuman(cfg, if_train=True)
    train_sampler = data.Infinite(data.RandomSampler(
        train_dataset, batch_size=cfg.batch_per_gpu, drop_last=True,
        world_size=gpu_num, rank=rank,
    ))
    train_loader = data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        collator=train_dataset,  # the dataset object also acts as the batch collator
        num_workers=4,
    )
    # The sampler is wrapped in data.Infinite, so a single iterator is shared by
    # all epochs; train_one_epoch receives max_steps (presumably batches to draw).
    train_loader = iter(train_loader)

    # Steps per epoch for this process; invariant across epochs.
    max_steps = cfg.nr_images_epoch // (cfg.batch_per_gpu * gpu_num)
    logger.info("Training...")
    for epoch_id in range(start_epoch, cfg.max_epoch):
        # Step decay: multiply by lr_decay_rate once for every milestone in
        # cfg.lr_decay_sates that is <= epoch_id.
        lr = base_lr * cfg.lr_decay_rate ** bisect.bisect_right(cfg.lr_decay_sates, epoch_id)
        for param_group in opt.param_groups:
            param_group["lr"] = lr

        train_one_epoch(model, gm, train_loader, opt, max_steps, rank, epoch_id, gpu_num)
        if rank == 0:
            save_path = osp.join(cfg.model_dir, 'epoch-{}.pkl'.format(epoch_id + 1))
            state_dict = model.state_dict()
            # Remove 'inputs.*' entries before saving; keys are collected first
            # because a dict must not be mutated while it is being iterated.
            for name in [k for k in state_dict if k.startswith('inputs.')]:
                del state_dict[name]

            mge.save({"epoch": epoch_id, "state_dict": state_dict}, save_path)
            logger.info("dump weights to %s", save_path)