Example #1
def setup_distributed(num_images=None):
    """Setup distributed related parameters."""
    # init distributed
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(
            FLAGS.data_loader_workers / udist.get_local_size()
        )  # per-GPU workers (round() returns the nearest integer)
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = \
                FLAGS.bn_calibration_per_gpu_batch_size * count
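    # linear LR scaling: scale the base learning rate by the ratio of the
    # actual global batch size to the reference total batch size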
    if hasattr(FLAGS, 'base_lr'):
        FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    if num_images:
        # NOTE: the last batch is not dropped, so use ceil (the smallest
        # integer not less than x); otherwise the learning rate can become
        # negative
        FLAGS._steps_per_epoch = math.ceil(num_images / FLAGS.batch_size)
Example #2
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters, criterions, end):
        top1_meter, top5_meter, loss_meter, data_time = meters
        criterion = criterions[0]
        world_size = dist.get_world_size()

        lr_scheduler.step(self.cur_step)
        self.cur_step += 1
        data_time.update(time.time() - end)

        self.model.zero_grad()
        out = self.model(x)
        loss = criterion(out, y)
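        # scale the loss by 1/world_size so the summing all_reduce below
        # reports the cross-rank average without div=True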
        loss /= world_size

        top1, top5 = accuracy(out, y, top_k=(1, 5))
        reduced_loss = dist.all_reduce(loss.clone())
        reduced_top1 = dist.all_reduce(top1.clone(), div=True)
        reduced_top5 = dist.all_reduce(top5.clone(), div=True)

        loss_meter.update(reduced_loss.item())
        top1_meter.update(reduced_top1.item())
        top5_meter.update(reduced_top5.item())

        loss.backward()
        dist.average_gradient(self.model.parameters())
        optimizer.step()
Example #3
def reduce_and_flush_meters(meters, method='avg'):
    """Sync and flush meters."""
    if not FLAGS.use_distributed:
        results = flush_scalar_meters(meters)
    else:
        results = {}
        assert isinstance(meters, dict), "meters should be a dict."
        # NOTE: Ensure same order, otherwise may deadlock
        for name in sorted(meters.keys()):
            meter = meters[name]
            if not isinstance(meter, ScalarMeter):
                continue
            if method == 'avg':
                method_fun = torch.mean
            elif method == 'sum':
                method_fun = torch.sum
            elif method == 'max':
                method_fun = torch.max
            elif method == 'min':
                method_fun = torch.min
            else:
                raise NotImplementedError(
                    'flush method: {} is not yet implemented.'.format(method))
            tensor = torch.tensor(meter.values).cuda()
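            # gather every rank's raw meter values, then reduce them with the
            # chosen method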
            gather_tensors = [
                torch.ones_like(tensor) for _ in range(udist.get_world_size())
            ]
            dist.all_gather(gather_tensors, tensor)
            value = method_fun(torch.cat(gather_tensors))
            meter.flush(value)
            results[name] = value
    return results
Example #4
def reduce_tensor(inp):
    """
    Reduce the loss from all processes so that 
    process with rank 0 has the averaged results.
    """
    world_size = dist.get_world_size()
    if world_size < 2:
        return inp
    with torch.no_grad():
        reduced_inp = inp
        torch.distributed.reduce(reduced_inp, dst=0)
    return reduced_inp / world_size
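
A minimal two-process smoke test for reduce_tensor (not part of the original source): it assumes the function above lives in the same file, uses the CPU gloo backend, and prints the averaged value on rank 0 only.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    # hypothetical local rendezvous; address and port are placeholders
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.tensor(float(rank + 1))  # rank 0 -> 1.0, rank 1 -> 2.0
    averaged = reduce_tensor(local)
    if rank == 0:
        print(averaged.item())             # (1.0 + 2.0) / 2 = 1.5
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)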
Example #5
def main():
    """Entry."""
    # init distributed
    global is_root_rank
    if FLAGS.use_distributed:
        udist.init_dist()
        FLAGS.batch_size = udist.get_world_size() * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.per_gpu_batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size
        FLAGS.data_loader_workers = round(FLAGS.data_loader_workers /
                                          udist.get_local_size())
        is_root_rank = udist.is_master()
    else:
        count = torch.cuda.device_count()
        FLAGS.batch_size = count * FLAGS.per_gpu_batch_size
        FLAGS._loader_batch_size = FLAGS.batch_size
        if FLAGS.bn_calibration:
            FLAGS._loader_batch_size_calib = FLAGS.bn_calibration_per_gpu_batch_size * count
        is_root_rank = True
    FLAGS.lr = FLAGS.base_lr * (FLAGS.batch_size / FLAGS.base_total_batch)
    # NOTE: the last batch is not dropped, so use ceil; otherwise the learning
    # rate can become negative
    FLAGS._steps_per_epoch = int(np.ceil(NUM_IMAGENET_TRAIN /
                                         FLAGS.batch_size))

    if is_root_rank:
        FLAGS.log_dir = '{}/{}'.format(FLAGS.log_dir,
                                       time.strftime("%Y%m%d-%H%M%S"))
        create_exp_dir(
            FLAGS.log_dir,
            FLAGS.config_path,
            blacklist_dirs=[
                'exp',
                '.git',
                'pretrained',
                'tmp',
                'deprecated',
                'bak',
            ],
        )
        setup_logging(FLAGS.log_dir)
        for k, v in _ENV_EXPAND.items():
            logging.info('Env var expand: {} to {}'.format(k, v))
        logging.info(FLAGS)

    set_random_seed(FLAGS.get('random_seed', 0))
    with SummaryWriterManager():
        train_val_test()
Example #6
def build_data_loader():
    logger.info("build train dataset")
    # train_dataset
    train_dataset = TrainDataset()
    logger.info("build dataset done")

    train_sampler = None
    if get_world_size() > 1:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.TRAIN.BATCH_SIZE,
                                  num_workers=cfg.TRAIN.NUM_WORKERS,
                                  pin_memory=True,
                                  sampler=train_sampler)
    return train_dataloader
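
One detail the snippet above leaves out: when a DistributedSampler is in use, it is usually re-seeded once per epoch so that each epoch sees a different shuffle. A minimal sketch, assuming the loader returned by build_data_loader and a hypothetical num_epochs:

train_dataloader = build_data_loader()
for epoch in range(num_epochs):  # num_epochs is a placeholder
    if isinstance(train_dataloader.sampler, DistributedSampler):
        # re-seed the sampler so every epoch uses a different shuffle order
        train_dataloader.sampler.set_epoch(epoch)
    for batch in train_dataloader:
        pass  # training step goes here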
Example #7
def KineticsSounds(cfg, split):
    if split == 'train':
        max_idx = 19
    elif split == 'val':
        max_idx = 1
    elif split == 'test':
        max_idx = 2
    dataset_root = cfg.DATASET_ROOT
    if dataset_root.endswith('/'):
        dataset_root = dataset_root[:-1]
    url = f"{dataset_root}/KineticsSounds/shards-{split}/shard-{{000000..{max_idx:06d}}}.tar"
    if cfg.STORAGE_SAS_KEY:
        url += cfg.STORAGE_SAS_KEY

    _decoder = Decoder(cfg, "KineticsSounds", split)
    if split == 'train':
        batch_size = int(cfg.TRAIN.BATCH_SIZE /
                         cfg.SOLVER.GRADIENT_ACCUMULATION_STEPS)
        batch_size = int(batch_size / du.get_world_size())
        length = int(cfg.TRAIN.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    elif split == 'val':
        batch_size = int(cfg.TRAIN.BATCH_SIZE / du.get_world_size())
        length = int(cfg.VAL.DATASET_SIZE / du.get_world_size())
        nominal = int(length / batch_size)
    elif split == 'test':
        batch_size = int(cfg.TEST.BATCH_SIZE / du.get_world_size())
        length = math.ceil(cfg.TEST.DATASET_SIZE / du.get_world_size())
        nominal = math.ceil(length / batch_size)

    wds.filters.batched = wds.filters.Curried(
        partial(wds.filters.batched_, collation_fn=COLLATE_FN["kinetics"]))

    dataset = wds.Dataset(
        url,
        handler=wds.warn_and_continue,
        shard_selection=du.shard_selection,
        length=length,
    )
    if split == 'train':
        dataset = dataset.shuffle(100)
    dataset = (dataset.map_dict(
        handler=wds.warn_and_continue,
        mp4=_decoder.mp4decode,
        json=_decoder.jsondecode,
    ))
    if cfg.DATA_LOADER.NUM_WORKERS > 0:
        length = nominal
    else:
        nominal = length
    dataset = wds.ResizedDataset(
        dataset,
        length=length,
        nominal=nominal,
    )
    return dataset
Example #8
def check_dist_init(config, logger):
    # check distributed initialization
    if config.distributed.enable:
        import os
        # for slurm
        try:
            node_id = int(os.environ['SLURM_NODEID'])
        except KeyError:
            return

        rank = dist.get_rank()
        world_size = dist.get_world_size()
        gpu_id = dist.gpu_id

        logger.info('World: {}/Node: {}/Rank: {}/GpuId: {} initialized.'
                    .format(world_size, node_id, rank, gpu_id))
Example #9
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters,
                         criterions, end):
        top1_meter, top5_meter, loss_meter, data_time = meters
        criterion, distill_loss = criterions
        world_size = dist.get_world_size()
        max_width = self.config.training.sandwich.max_width

        lr_scheduler.step(self.cur_step)
        self.cur_step += 1
        data_time.update(time.time() - end)

        self.model.zero_grad()

        max_pred = None
        for idx in range(self.config.training.sandwich.num_sample):
            # sandwich rule
            top1_m, top5_m, loss_m = self._set_width(idx, top1_meter,
                                                     top5_meter, loss_meter)

            out = self.model(x)
            if self.config.training.distillation.enable:
                if idx == 0:
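                    # the first (widest) sub-network's detached output becomes
                    # the soft target that the narrower widths distill from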
                    max_pred = out.detach()
                    loss = criterion(out, y)
                else:
                    loss = self.config.training.distillation.loss_weight * \
                           distill_loss(out, max_pred)
                    if self.config.training.distillation.hard_label:
                        loss += criterion(out, y)
            else:
                loss = criterion(out, y)
            loss /= world_size

            top1, top5 = accuracy(out, y, top_k=(1, 5))
            reduced_loss = dist.all_reduce(loss.clone())
            reduced_top1 = dist.all_reduce(top1.clone(), div=True)
            reduced_top5 = dist.all_reduce(top5.clone(), div=True)

            loss_m.update(reduced_loss.item())
            top1_m.update(reduced_top1.item())
            top5_m.update(reduced_top5.item())

            loss.backward()

        dist.average_gradient(self.model.parameters())
        optimizer.step()
Example #10
    def _train_one_batch(self, x, y, optimizer, lr_scheduler, meters,
                         criterions, end):
        lr_scheduler, arch_lr_scheduler = lr_scheduler
        optimizer, arch_optimizer = optimizer
        top1_meter, top5_meter, loss_meter, arch_loss_meter, \
            floss_meter, eflops_meter, arch_top1_meter, data_time = meters
        criterion, _ = criterions

        self.model.module.set_alpha_training(False)
        super(DMCPRunner, self)._train_one_batch(
            x, y, optimizer, lr_scheduler,
            [top1_meter, top5_meter, loss_meter, data_time], criterions, end)

        arch_lr_scheduler.step(self.cur_step)
        world_size = dist.get_world_size()

        # train architecture params
        if self.cur_step >= self.config.arch.start_train \
                and self.cur_step % self.config.arch.train_freq == 0:
            self._set_width(0, top1_meter, top5_meter, loss_meter)
            self.model.module.set_alpha_training(True)

            self.model.zero_grad()
            arch_out = self.model(x)
            arch_loss = criterion(arch_out, y)
            arch_loss /= world_size
            floss, eflops = flop_loss(self.config, self.model)
            floss /= world_size

            arch_top1 = accuracy(arch_out, y, top_k=(1, ))[0]
            reduced_arch_loss = dist.all_reduce(arch_loss.clone())
            reduced_floss = dist.all_reduce(floss.clone())
            reduced_eflops = dist.all_reduce(eflops.clone(), div=True)
            reduced_arch_top1 = dist.all_reduce(arch_top1.clone(), div=True)

            arch_loss_meter.update(reduced_arch_loss.item())
            floss_meter.update(reduced_floss.item())
            eflops_meter.update(reduced_eflops.item())
            arch_top1_meter.update(reduced_arch_top1.item())

            floss.backward()
            arch_loss.backward()
            dist.average_gradient(self.model.module.arch_parameters())
            arch_optimizer.step()
Example #11
def setup_logging(output_dir=None):
    """
    Sets up the logging for multiple processes. Only enable the logging for the
    master process, and suppress logging for the non-master processes.
    """
    # Set up logging format.
    _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s"

    if du.is_master_proc():
        # Enable logging for the master process.
        logging.root.handlers = []
        logging.basicConfig(level=logging.INFO,
                            format=_FORMAT,
                            stream=sys.stdout)
    else:
        # Suppress logging for non-master processes.
        _suppress_print()

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    plain_formatter = logging.Formatter(
        "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    if du.is_master_proc():
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(plain_formatter)
        logger.addHandler(ch)

    if output_dir is not None and du.is_master_proc(du.get_world_size()):
        filename = os.path.join(output_dir, "stdout.log")
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
Example #12
def train(train_dataloader, model, optimizer, lr_scheduler):
    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    logger.info("model\n{}".format(describe(model.module)))
    tb_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)
    average_meter = AverageMeter()
    start_epoch = cfg.TRAIN.START_EPOCH
    world_size = get_world_size()
    num_per_epoch = len(
        train_dataloader.dataset) // (cfg.TRAIN.BATCH_SIZE * world_size)
    iter = 0
    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and get_rank() == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)
    for epoch in range(cfg.TRAIN.START_EPOCH, cfg.TRAIN.EPOCHS):
        if cfg.BACKBONE.TRAIN_EPOCH == epoch:
            logger.info('begin to train backbone!')
            optimizer, lr_scheduler = build_optimizer_lr(model.module, epoch)
            logger.info("model\n{}".format(describe(model.module)))
        train_dataloader.dataset.shuffle()
        lr_scheduler.step(epoch)
        # log for lr
        if get_rank() == 0:
            for idx, pg in enumerate(optimizer.param_groups):
                tb_writer.add_scalar('lr/group{}'.format(idx + 1), pg['lr'],
                                     iter)
        cur_lr = lr_scheduler.get_cur_lr()
        for data in train_dataloader:
            begin = time.time()
            examplar_img = data['examplar_img'].cuda()
            search_img = data['search_img'].cuda()
            gt_cls = data['gt_cls'].cuda()
            gt_delta = data['gt_delta'].cuda()
            delta_weight = data['delta_weight'].cuda()
            data_time = time.time() - begin
            losses = model.forward(examplar_img, search_img, gt_cls, gt_delta,
                                   delta_weight)
            cls_loss = losses['cls_loss']
            loc_loss = losses['loc_loss']
            loss = losses['total_loss']

            if is_valid_number(loss.item()):
                optimizer.zero_grad()
                loss.backward()
                reduce_gradients(model)
                if get_rank() == 0 and cfg.TRAIN.LOG_GRAD:
                    log_grads(model.module, tb_writer, iter)
                clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
                optimizer.step()

            batch_time = time.time() - begin
            batch_info = {}
            batch_info['data_time'] = average_reduce(data_time)
            batch_info['batch_time'] = average_reduce(batch_time)
            for k, v in losses.items():
                batch_info[k] = average_reduce(v)
            average_meter.update(**batch_info)
            if get_rank() == 0:
                for k, v in batch_info.items():
                    tb_writer.add_scalar(k, v, iter)
                if iter % cfg.TRAIN.PRINT_EVERY == 0:
                    logger.info(
                        'epoch: {}, iter: {}, cur_lr:{}, cls_loss: {}, loc_loss: {}, loss: {}'
                        .format(epoch + 1, iter, cur_lr, cls_loss.item(),
                                loc_loss.item(), loss.item()))
                    print_speed(iter + 1 + start_epoch * num_per_epoch,
                                average_meter.batch_time.avg,
                                cfg.TRAIN.EPOCHS * num_per_epoch)
            iter += 1
        # save model
        if get_rank() == 0:
            state = {
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1
            }
            logger.info('save snapshot to {}/checkpoint_e{}.pth'.format(
                cfg.TRAIN.SNAPSHOT_DIR, epoch + 1))
            torch.save(
                state, '{}/checkpoint_e{}.pth'.format(cfg.TRAIN.SNAPSHOT_DIR,
                                                      epoch + 1))
Example #13
def data_loader(train_set, val_set, test_set):
    """get data loader"""
    train_loader = None
    val_loader = None
    test_loader = None
    # infer batch size
    if getattr(FLAGS, 'batch_size', False):
        if getattr(FLAGS, 'batch_size_per_gpu', False):
            assert FLAGS.batch_size == (
                FLAGS.batch_size_per_gpu * FLAGS.num_gpus_per_job)
        else:
            assert FLAGS.batch_size % FLAGS.num_gpus_per_job == 0
            FLAGS.batch_size_per_gpu = (
                FLAGS.batch_size // FLAGS.num_gpus_per_job)
    elif getattr(FLAGS, 'batch_size_per_gpu', False):
        FLAGS.batch_size = FLAGS.batch_size_per_gpu * FLAGS.num_gpus_per_job
    else:
        raise ValueError('batch size (per gpu) is not defined')
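    # split the global batch size evenly across distributed processes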
    batch_size = int(FLAGS.batch_size / get_world_size())
    if FLAGS.data_loader == 'imagenet1k_basic':
        if getattr(FLAGS, 'distributed', False):
            if FLAGS.test_only:
                train_sampler = None
            else:
                train_sampler = DistributedSampler(train_set)
            val_sampler = DistributedSampler(val_set)
        else:
            train_sampler = None
            val_sampler = None
        if not FLAGS.test_only:
            train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=batch_size,
                shuffle=(train_sampler is None),
                sampler=train_sampler,
                pin_memory=True,
                num_workers=FLAGS.data_loader_workers,
                drop_last=getattr(FLAGS, 'drop_last', False))
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=batch_size,
            shuffle=False,
            sampler=val_sampler,
            pin_memory=True,
            num_workers=FLAGS.data_loader_workers,
            drop_last=getattr(FLAGS, 'drop_last', False))
        test_loader = val_loader
    else:
        try:
            data_loader_lib = importlib.import_module(FLAGS.data_loader)
            return data_loader_lib.data_loader(train_set, val_set, test_set)
        except ImportError:
            raise NotImplementedError(
                'Data loader {} is not yet implemented.'.format(
                    FLAGS.data_loader))
    if train_loader is not None:
        FLAGS.data_size_train = len(train_loader.dataset)
    if val_loader is not None:
        FLAGS.data_size_val = len(val_loader.dataset)
    if test_loader is not None:
        FLAGS.data_size_test = len(test_loader.dataset)
    return train_loader, val_loader, test_loader