Example #1
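    # A minimal wrapper, apparently assuming `parallel` is apex.parallel and that
    # `world_size` / `use_gpu` are set elsewhere (e.g. during distributed init):
    # apex DDP is used on GPU, native torch DDP as the CPU fallback.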
    def parallelize(model):
        if world_size <= 1:
            return model

        if use_gpu:
            model.top_model = parallel.DistributedDataParallel(model.top_model)
        else:  # Use other backend for CPU
            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
        return model
Example #2
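# Presumably `parallel` is apex.parallel, `dist` is torch.distributed and `nn` is
# torch.nn. Note that using the global rank as the device id assumes a single-node
# setup where global rank == local GPU index.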
def set_model_dist(net):
    if has_apex:
        net = parallel.DistributedDataParallel(net, delay_allreduce=True)
    else:
        local_rank = dist.get_rank()
        net = nn.parallel.DistributedDataParallel(
            net,
            device_ids=[local_rank],
            output_device=local_rank)
    return net
Example #3
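# Micro-benchmark of a DLRM-style top MLP. `dist`, `dot_interaction`, `utils` and
# the absl-style `FLAGS` object are assumed to come from the surrounding project;
# `amp` and `parallel` are presumably apex.amp / apex.parallel.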
def main(argv):
    rank, world_size, gpu = dist.init_distributed_mode()

    top_mlp = create_top_mlp().to("cuda")
    print(top_mlp)

    optimizer = torch.optim.SGD(top_mlp.parameters(), lr=1.)

    if FLAGS.fp16:
        top_mlp, optimizer = amp.initialize(top_mlp,
                                            optimizer,
                                            opt_level="O1",
                                            loss_scale=1)

    if world_size > 1:
        top_mlp = parallel.DistributedDataParallel(top_mlp)
        model_without_ddp = top_mlp.module

    dummy_bottom_mlp_output = torch.rand(FLAGS.batch_size,
                                         EMBED_DIM,
                                         device="cuda")
    dummy_embedding_output = torch.rand(FLAGS.batch_size,
                                        26 * EMBED_DIM,
                                        device="cuda")
    dummy_target = torch.ones(FLAGS.batch_size, device="cuda")

    if FLAGS.fp16:
        dummy_bottom_mlp_output = dummy_bottom_mlp_output.to(torch.half)
        dummy_embedding_output = dummy_embedding_output.to(torch.half)

    # warm up GPU
    for _ in range(100):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output],
                                          FLAGS.batch_size)
        output = top_mlp(interaction_out)

    start_time = utils.timer_start()
    for _ in range(FLAGS.num_iters):
        interaction_out = dot_interaction(dummy_bottom_mlp_output,
                                          [dummy_embedding_output],
                                          FLAGS.batch_size)
        output = top_mlp(interaction_out).squeeze()
        dummy_loss = output.mean()
        optimizer.zero_grad()
        if FLAGS.fp16:
            with amp.scale_loss(dummy_loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            dummy_loss.backward()
        optimizer.step()
    stop_time = utils.timer_stop()

    elapsed_time = (stop_time - start_time) / FLAGS.num_iters * 1e3
    print(F"Average step time: {elapsed_time:.4f} ms.")
Example #4
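# Single-node, 8-GPU training worker for a semantic segmentation model. `args`,
# `deeplab`, `semantic_kitti` and `utils` are assumed project modules; `parallel`
# is presumably apex.parallel (sync-BN conversion plus DistributedDataParallel).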
def train(rank):
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    dist.init_process_group(backend="nccl",
                            init_method="tcp://localhost:34567",
                            world_size=8,
                            rank=rank)
    dist.barrier()

    model = deeplab.resnext101_aspp_kp(19)
    torch.cuda.set_device(rank)
    if rank == 0:
        writer = SummaryWriter(log_dir=args.checkpoint_dir, flush_secs=20)
    model = parallel.convert_syncbn_model(model)
    model.cuda(rank)

    model.load_state_dict(
        torch.load(args.model_dir / "resnext_cityscapes_2p.pth",
                   map_location=f"cuda:{rank}"),
        strict=False,
    )
    dist.barrier()
    if rank == 0:
        print(model.parameters)
    model = parallel.DistributedDataParallel(model)
    train_dataset = semantic_kitti.SemanticKitti(
        args.semantic_kitti_dir / "dataset/sequences",
        "train",
    )

    train_sampler = utils.dist_utils.TrainingSampler(train_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=3,
        num_workers=8,
        drop_last=True,
        shuffle=False,
        pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=semantic_kitti.SemanticKitti(
            args.semantic_kitti_dir / "dataset/sequences",
            "val",
        ),
        batch_size=1,
        shuffle=False,
        num_workers=4,
        drop_last=False,
    )

    loss_fn = utils.ohem.OhemCrossEntropy(ignore_index=255,
                                          thresh=0.9,
                                          min_kept=10000)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.00001,
                                momentum=0.9,
                                weight_decay=1e-4)
    scheduler = utils.cosine_schedule.CosineAnnealingWarmUpRestarts(
        optimizer, T_0=96000, T_mult=10, eta_max=0.01875, T_up=1000, gamma=0.5)
    n_iter = 0
    for epoch in range(120):
        model.train()
        for step, items in enumerate(train_loader):
            images = items["image"].cuda(rank, non_blocking=True)
            labels = items["labels"].long().cuda(rank, non_blocking=True)
            py = items["py"].float().cuda(rank, non_blocking=True)
            px = items["px"].float().cuda(rank, non_blocking=True)
            pxyz = items["points_xyz"].float().cuda(rank, non_blocking=True)
            knns = items["knns"].long().cuda(rank, non_blocking=True)
            predictions = model(images, px, py, pxyz, knns)

            loss = loss_fn(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            optimizer.step()
            if rank == 0:
                print(
                    f"Epoch: {epoch} Iteration: {step} / {len(train_loader)} Loss: {loss.item()}"
                )
                writer.add_scalar("loss/train", loss.item(), n_iter)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"],
                                  n_iter)
            n_iter += 1
            scheduler.step()

        if rank == 0:
            if (epoch + 1) % 5 == 0:
                run_val(model, val_loader, n_iter, writer)
            torch.save(model.module.state_dict(),
                       args.checkpoint_dir / f"epoch{epoch}.pth")
Example #5
def main(argv):
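    # Distributed DLRM training entry point: embedding tables are model-parallel
    # across GPUs while the top MLP is data-parallel. `FLAGS`, `dist`, `utils`,
    # `dllogger` and the apex `amp` / `parallel` / `apex_optim` modules are
    # assumed to be imported by the surrounding project.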
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    device = FLAGS.base_device

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size",
                      FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes,
                                        num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size,
                                              num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping[
        'bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(
        FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP averages gradients via allreduce(mean), which does not apply to the bottom model.
    # Compensate for this by scaling the learning rate further.
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr

    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = scaled_lr
    else:
        embedding_model_parallel_lr = scaled_lr / world_size
    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = scaled_lr
    else:
        MLP_model_parallel_lr = scaled_lr / world_size
    data_parallel_lr = scaled_lr

    if is_main_process():
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }, {
            'params': list(model.bottom_model.mlp.parameters()),
            'lr': MLP_model_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params':
        list(model.bottom_model.embeddings.parameters()),
        'lr':
        embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(
        device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model,
         model.bottom_model.mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1)

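    # Only the top MLP is wrapped in DDP (data-parallel); the model-parallel
    # embeddings are updated locally on each rank without gradient averaging.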
    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process(
    ):
        logging.warning(
            "Saving checkpoint without --bottom_features_ordered flag will result in "
            "a device-order dependent model. Consider using --bottom_features_ordered "
            "if you plan to load the checkpoint in different device configurations."
        )

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print every 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[mlp_lrs, embedding_lrs],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported",
                              click.shape[0])
            else:
                output = model(numerical_features, categorical_features,
                               batch_sizes_per_gpu).squeeze()

                loss = loss_fn(
                    output, click[batch_indices[rank]:batch_indices[rank + 1]])

                if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
                    model.zero_grad()
                else:
                    # No gradient accumulation is needed; setting .grad to None is faster than optimizer.zero_grad()
                    for param_group in itertools.chain(
                            embedding_optimizer.param_groups,
                            mlp_optimizer.param_groups):
                        for param in param_group['params']:
                            param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if FLAGS.Adam_MLP_optimizer:
                    scale_MLP_gradients(mlp_optimizer, world_size)
                mlp_optimizer.step()

                if FLAGS.Adam_embedding_optimizer:
                    scale_embeddings_gradients(embedding_optimizer, world_size)
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Average over a print_freq window to reduce error; accurate timing
                # would require a synchronize, which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path,
                                          epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
Example #6
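    # Fragment of a dialogue-state-tracking training script: `hparams`, `Reader`,
    # `DST`, `load`, `train` and `validate` are project-specific, and `amp` /
    # `parallel` are presumably apex modules. The snippet is truncated before the
    # body of the epoch loop.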
    random.seed(hparams.seed)
    reader = Reader(hparams)
    start = time.time()
    logger.info("Loading data...")
    reader.load_data("train")
    end = time.time()
    logger.info("Loaded. {} secs".format(end - start))

    model = DST(hparams).cuda()
    optimizer = Adam(model.parameters(), hparams.lr)
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level="O1",
                                      verbosity=0)
    model = parallel.DistributedDataParallel(model)

    # Load a saved model and optimizer state, if a save path is given
    if hparams.save_path is not None:
        load(model, optimizer, hparams.save_path)
        torch.distributed.barrier()

    train.max_iter = len(list(reader.make_batch(reader.train)))
    validate.max_iter = len(list(reader.make_batch(reader.dev)))
    train.warmup_steps = train.max_iter * hparams.max_epochs * hparams.warmup_steps

    train.global_step = 0
    max_joint_acc = 0
    early_stop_count = hparams.early_stop_count

    for epoch in range(hparams.max_epochs):
Example #7
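    # Trainer method from a GAN-style training framework: `self.runner` holds the
    # networks and optimizers, `ds_utils` and `Logger` are assumed project modules,
    # and apex `amp` / `parallel` are imported conditionally inside the method.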
    def train(self, args):
        # Reset amp
        if args.use_apex:
            from apex import amp
            
            amp.init(False)

        # Get dataloaders
        train_dataloader = ds_utils.get_dataloader(args, 'train')
        if not args.skip_test:
            test_dataloader = ds_utils.get_dataloader(args, 'test')

        model = runner = self.runner

        if args.use_half:
            runner.half()

        # Initialize optimizers, schedulers and apex
        opts = runner.get_optimizers(args)

        # Load pre-trained params for optimizers and schedulers (if needed)
        if args.which_epoch != 'none' and not args.init_experiment_dir:
            for net_name, opt in opts.items():
                opt.load_state_dict(torch.load(self.checkpoints_dir / f'{args.which_epoch}_opt_{net_name}.pth', map_location='cpu'))

        if args.use_apex and args.num_gpus > 0 and args.num_gpus <= 8:
            # Enforce apex mixed precision settings
            nets_list, opts_list = [], []
            for net_name in sorted(opts.keys()):
                nets_list.append(runner.nets[net_name])
                opts_list.append(opts[net_name])

            loss_scale = float(args.amp_loss_scale) if args.amp_loss_scale != 'dynamic' else args.amp_loss_scale

            nets_list, opts_list = amp.initialize(nets_list, opts_list, opt_level=args.amp_opt_level, num_losses=1, loss_scale=loss_scale)

            # Unpack opts_list into optimizers
            for net_name, net, opt in zip(sorted(opts.keys()), nets_list, opts_list):
                runner.nets[net_name] = net
                opts[net_name] = opt

            if args.which_epoch != 'none' and not args.init_experiment_dir and os.path.exists(self.checkpoints_dir / f'{args.which_epoch}_amp.pth'):
                amp.load_state_dict(torch.load(self.checkpoints_dir / f'{args.which_epoch}_amp.pth', map_location='cpu'))

        # Initialize apex distributed data parallel wrapper
        if args.num_gpus > 1 and args.num_gpus <= 8:
            from apex import parallel

            model = parallel.DistributedDataParallel(runner, delay_allreduce=True)

        epoch_start = 1 if args.which_epoch == 'none' else int(args.which_epoch) + 1

        # Initialize logging
        train_iter = epoch_start - 1

        if args.visual_freq != -1:
            train_iter /= args.visual_freq

        logger = Logger(args, self.experiment_dir)
        logger.set_num_iter(
            train_iter=train_iter, 
            test_iter=(epoch_start - 1) // args.test_freq)

        if args.debug and not args.use_apex:
            torch.autograd.set_detect_anomaly(True)

        total_iters = 1

        for epoch in range(epoch_start, args.num_epochs + 1):
            if args.rank == 0: 
                print('epoch %d' % epoch)

            # Train for one epoch
            model.train()
            time_start = time.time()

            # Shuffle the dataset before the epoch
            train_dataloader.dataset.shuffle()

            for i, data_dict in enumerate(train_dataloader, 1):               
                # Prepare input data
                if args.num_gpus > 0:
                    for key, value in data_dict.items():
                        data_dict[key] = value.cuda()

                # Convert inputs to FP16
                if args.use_half:
                    for key, value in data_dict.items():
                        data_dict[key] = value.half()

                output_logs = i == len(train_dataloader)

                if args.visual_freq != -1:
                    output_logs = not (total_iters % args.visual_freq)

                output_visuals = output_logs and not args.no_disk_write_ops

                # Zero out the gradients of all optimizers before the forward pass
                for opt in opts.values():
                    opt.zero_grad()

                # Perform a forward pass
                if not args.use_closure:
                    loss = model(data_dict)
                    closure = None

                if args.use_apex and args.num_gpus > 0 and args.num_gpus <= 8:
                    # Mixed precision requires a special wrapper for the loss
                    with amp.scale_loss(loss, opts.values()) as scaled_loss:
                        scaled_loss.backward()

                elif not args.use_closure:
                    loss.backward()

                else:
                    def closure():
                        loss = model(data_dict)
                        loss.backward()
                        return loss

                # Perform steps for all optimizers
                for opt in opts.values():
                    opt.step(closure)

                if output_logs:
                    logger.output_logs('train', runner.output_visuals(), runner.output_losses(), time.time() - time_start)

                    if args.debug:
                        break

                if args.visual_freq != -1:
                    total_iters += 1
                    total_iters %= args.visual_freq
            
            # Increment the epoch counter in the training dataset
            train_dataloader.dataset.epoch += 1

            # If testing is not required -- continue
            if epoch % args.test_freq:
                continue

            # If the skip-test flag is set -- only check whether a checkpoint is required
            if not args.skip_test:
                # Calculate "standing" stats for the batch normalization
                if args.calc_stats:
                    runner.calculate_batchnorm_stats(train_dataloader, args.debug)

                # Test
                time_start = time.time()
                model.eval()

                for data_dict in test_dataloader:
                    # Prepare input data
                    if args.num_gpus > 0:
                        for key, value in data_dict.items():
                            data_dict[key] = value.cuda()

                    # Forward pass
                    with torch.no_grad():
                        model(data_dict)
                    
                    if args.debug:
                        break

            # Output logs
            logger.output_logs('test', runner.output_visuals(), runner.output_losses(), time.time() - time_start)
            
            # If creation of checkpoint is not required -- continue
            if epoch % args.checkpoint_freq and not args.debug:
                continue

            # Create or load a checkpoint
            if args.rank == 0  and not args.no_disk_write_ops:
                with torch.no_grad():
                    for net_name in runner.nets_names_to_train:
                        # Save a network
                        torch.save(runner.nets[net_name].state_dict(), self.checkpoints_dir / f'{epoch}_{net_name}.pth')

                        # Save an optimizer
                        torch.save(opts[net_name].state_dict(), self.checkpoints_dir / f'{epoch}_opt_{net_name}.pth')

                    # Save amp
                    if args.use_apex:
                        torch.save(amp.state_dict(), self.checkpoints_dir / f'{epoch}_amp.pth')

        return runner
Example #8
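# MLPerf-style distributed DLRM training: embedding tables are model-parallel
# across ranks while the top MLP is data-parallel. `FLAGS`, `dist`, `dist_model`,
# `dataset`, `utils`, `mlperf_logger` and the apex `amp` / `parallel` /
# `apex_optim` modules are assumed to be imported by the surrounding project.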
def main(argv):
    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    # Initialize distributed mode
    use_gpu = "cpu" not in FLAGS.device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    if world_size == 1:
        raise NotImplementedError(
            "This file is only for distributed training.")

    mlperf_logger.mlperf_submission_log('dlrm')
    mlperf_logger.log_event(key=mlperf_logger.constants.SEED, value=FLAGS.seed)
    mlperf_logger.log_event(key=mlperf_logger.constants.GLOBAL_BATCH_SIZE,
                            value=FLAGS.batch_size)

    # Only print cmd args on rank 0
    if rank == 0:
        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    # Check arguments sanity
    if FLAGS.batch_size % world_size != 0:
        raise ValueError(
            F"Batch size {FLAGS.batch_size} is not divisible by world_size {world_size}."
        )
    if FLAGS.test_batch_size % world_size != 0:
        raise ValueError(
            F"Test batch size {FLAGS.test_batch_size} is not divisible by world_size {world_size}."
        )

    # Load config file, create sub config for each rank
    with open(FLAGS.model_config, "r") as f:
        config = json.loads(f.read())

    world_categorical_feature_sizes = np.asarray(
        config.pop('categorical_feature_sizes'))
    device_mapping = dist_model.get_criteo_device_mapping(world_size)
    vectors_per_gpu = device_mapping['vectors_per_gpu']
    # Get the sizes of the embeddings each GPU will create
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = config.pop('bottom_mlp_sizes')
    if rank != device_mapping['bottom_mlp']:
        bottom_mlp_sizes = None

    model = dist_model.DistDlrm(
        categorical_feature_sizes=categorical_feature_sizes,
        bottom_mlp_sizes=bottom_mlp_sizes,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        **config,
        device=FLAGS.device,
        use_embedding_ext=FLAGS.use_embedding_ext)
    print(model)

    dist.setup_distributed_print(rank == 0)

    # DDP averages gradients via allreduce(mean), which does not apply to the bottom model.
    # Compensate for this by scaling the learning rate further.
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.fp16 else FLAGS.lr
    scaled_lrs = [scaled_lr / world_size, scaled_lr]

    embedding_optimizer = torch.optim.SGD([
        {
            'params': model.bottom_model.joint_embedding.parameters(),
            'lr': scaled_lrs[0]
        },
    ])
    mlp_optimizer = apex_optim.FusedSGD([{
        'params':
        model.bottom_model.bottom_mlp.parameters(),
        'lr':
        scaled_lrs[0]
    }, {
        'params': model.top_model.parameters(),
        'lr': scaled_lrs[1]
    }])

    if FLAGS.fp16:
        (model.top_model,
         model.bottom_model.bottom_mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.bottom_mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1,
             cast_model_outputs=torch.float16)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Too many arguments to pass around for distributed training, so the training
    # loop is written inline here instead of in a separate train function.

    # Print every 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.4f} ms'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=FLAGS.device)
    moving_loss_stream = torch.cuda.Stream()

    local_embedding_device_mapping = torch.tensor(
        device_mapping['embedding'][rank],
        device=FLAGS.device,
        dtype=torch.long)

    # LR is logged twice for now because of a compliance checker bug
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
                            value=FLAGS.lr)
    mlperf_logger.log_event(key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS,
                            value=FLAGS.warmup_steps)

    # use logging keys from the official HP table and not from the logging library
    mlperf_logger.log_event(key='sgd_opt_base_learning_rate', value=FLAGS.lr)
    mlperf_logger.log_event(key='lr_decay_start_steps',
                            value=FLAGS.decay_start_step)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_steps',
                            value=FLAGS.decay_steps)
    mlperf_logger.log_event(key='sgd_opt_learning_rate_decay_poly_power',
                            value=FLAGS.decay_power)

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[scaled_lrs, [scaled_lrs[0]]],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    eval_data_cache = [] if FLAGS.cache_eval_data else None

    start_time = time()
    stop_time = time()

    print("Creating data loaders")
    dist_dataset_args = {
        "numerical_features": rank == 0,
        "categorical_features": device_mapping['embedding'][rank]
    }

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP)
    mlperf_logger.barrier()
    mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START)
    mlperf_logger.barrier()

    data_loader_train, data_loader_test = dataset.get_data_loader(
        FLAGS.dataset,
        FLAGS.batch_size,
        FLAGS.test_batch_size,
        FLAGS.device,
        dataset_type=FLAGS.dataset_type,
        shuffle=FLAGS.shuffle,
        **dist_dataset_args)

    steps_per_epoch = len(data_loader_train)

    # Default 20 tests per epoch
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch // 20

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        mlperf_logger.barrier()
        mlperf_logger.log_start(key=mlperf_logger.constants.BLOCK_START,
                                metadata={
                                    mlperf_logger.constants.FIRST_EPOCH_NUM:
                                    epoch + 1,
                                    mlperf_logger.constants.EPOCH_COUNT:
                                    1
                                })
        mlperf_logger.barrier()
        mlperf_logger.log_start(
            key=mlperf_logger.constants.EPOCH_START,
            metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})

        if FLAGS.profile_steps is not None:
            torch.cuda.profiler.start()
        for step, (numerical_features, categorical_features,
                   click) in enumerate(
                       dataset.prefetcher(iter(data_loader_train),
                                          data_stream)):
            torch.cuda.current_stream().wait_stream(data_stream)

            global_step = steps_per_epoch * epoch + step
            lr_scheduler.step()

            # Slice out categorical features if not using the "dist" dataset
            if FLAGS.dataset_type != "dist":
                categorical_features = categorical_features[:,
                                                            local_embedding_device_mapping]

            if FLAGS.fp16 and categorical_features is not None:
                numerical_features = numerical_features.to(torch.float16)

            last_batch_size = None
            if click.shape[0] != FLAGS.batch_size:  # last batch
                last_batch_size = click.shape[0]
                logging.debug("Pad the last batch of size %d to %d",
                              last_batch_size, FLAGS.batch_size)
                padding_size = FLAGS.batch_size - last_batch_size
                padding_numerical = torch.empty(
                    padding_size,
                    numerical_features.shape[1],
                    device=numerical_features.device,
                    dtype=numerical_features.dtype)
                numerical_features = torch.cat(
                    (numerical_features, padding_numerical), dim=0)
                if categorical_features is not None:
                    padding_categorical = torch.ones(
                        padding_size,
                        categorical_features.shape[1],
                        device=categorical_features.device,
                        dtype=categorical_features.dtype)
                    categorical_features = torch.cat(
                        (categorical_features, padding_categorical), dim=0)
                padding_click = torch.empty(padding_size,
                                            device=click.device,
                                            dtype=click.dtype)
                click = torch.cat((click, padding_click))

            bottom_out = model.bottom_model(numerical_features,
                                            categorical_features)

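            # `bottom_to_top` presumably performs the all-to-all style exchange
            # that turns the model-parallel bottom outputs into per-rank batch
            # slices for the data-parallel top MLP.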
            batch_size_per_gpu = FLAGS.batch_size // world_size
            from_bottom = dist_model.bottom_to_top(bottom_out,
                                                   batch_size_per_gpu,
                                                   config['embedding_dim'],
                                                   vectors_per_gpu)

            if last_batch_size is not None:
                partial_rank = math.ceil(last_batch_size / batch_size_per_gpu)
                if rank == partial_rank:
                    top_out = model.top_model(
                        from_bottom[:last_batch_size %
                                    batch_size_per_gpu]).squeeze().float()
                    loss = loss_fn(
                        top_out, click[rank * batch_size_per_gpu:(rank + 1) *
                                       batch_size_per_gpu][:last_batch_size %
                                                           batch_size_per_gpu])
                elif rank < partial_rank:
                    loss = loss_fn(
                        model.top_model(from_bottom).squeeze().float(),
                        click[rank * batch_size_per_gpu:(rank + 1) *
                              batch_size_per_gpu])
                else:
                    # Backpropagate nothing for padded samples
                    loss = 0. * model.top_model(
                        from_bottom).squeeze().float().mean()
            else:
                loss = loss_fn(
                    model.top_model(from_bottom).squeeze().float(),
                    click[rank * batch_size_per_gpu:(rank + 1) *
                          batch_size_per_gpu])

            # No gradient accumulation is needed; setting .grad to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(
                    embedding_optimizer.param_groups,
                    mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

            if FLAGS.fp16:
                loss *= FLAGS.loss_scale
                with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            mlp_optimizer.step()
            embedding_optimizer.step()

            moving_loss_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(moving_loss_stream):
                moving_loss += loss
            if step == 0:
                print(F"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.synchronize()
                # Average over a print_freq window to reduce error; accurate timing
                # would require a synchronize, which would slow things down.
                metric_logger.update(step_time=(time() - stop_time) * 1000 /
                                     print_freq,
                                     loss=moving_loss.item() / print_freq /
                                     (FLAGS.loss_scale if FLAGS.fp16 else 1),
                                     lr=mlp_optimizer.param_groups[1]["lr"] *
                                     (FLAGS.loss_scale if FLAGS.fp16 else 1))
                stop_time = time()
                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.avg / 1000 *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    F"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                mlperf_epoch_index = global_step / steps_per_epoch + 1

                mlperf_logger.barrier()
                mlperf_logger.log_start(key=mlperf_logger.constants.EVAL_START,
                                        metadata={
                                            mlperf_logger.constants.EPOCH_NUM:
                                            mlperf_epoch_index
                                        })
                auc = dist_evaluate(model, data_loader_test, eval_data_cache)
                mlperf_logger.log_event(
                    key=mlperf_logger.constants.EVAL_ACCURACY,
                    value=float(auc),
                    metadata={
                        mlperf_logger.constants.EPOCH_NUM: mlperf_epoch_index
                    })
                print(F"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()
                mlperf_logger.barrier()
                mlperf_logger.log_end(key=mlperf_logger.constants.EVAL_STOP,
                                      metadata={
                                          mlperf_logger.constants.EPOCH_NUM:
                                          mlperf_epoch_index
                                      })

                if auc > FLAGS.auc_threshold:
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                                          metadata={
                                              mlperf_logger.constants.STATUS:
                                              mlperf_logger.constants.SUCCESS
                                          })

                    mlperf_logger.barrier()
                    mlperf_logger.log_end(
                        key=mlperf_logger.constants.EPOCH_STOP,
                        metadata={
                            mlperf_logger.constants.EPOCH_NUM: epoch + 1
                        })
                    mlperf_logger.barrier()
                    mlperf_logger.log_end(
                        key=mlperf_logger.constants.BLOCK_STOP,
                        metadata={
                            mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1
                        })

                    run_time_s = int(stop_time - start_time)
                    print(
                        F"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        F"{global_step/steps_per_epoch:.2f} in {run_time_s}s. "
                        F"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    return

            if FLAGS.profile_steps is not None and global_step == FLAGS.profile_steps:
                torch.cuda.profiler.stop()
                logging.warning("Profile run, stopped at step %d.",
                                global_step)
                return

        mlperf_logger.barrier()
        mlperf_logger.log_end(
            key=mlperf_logger.constants.EPOCH_STOP,
            metadata={mlperf_logger.constants.EPOCH_NUM: epoch + 1})
        mlperf_logger.barrier()
        mlperf_logger.log_end(
            key=mlperf_logger.constants.BLOCK_STOP,
            metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: epoch + 1})

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            F"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            F"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    mlperf_logger.barrier()
    mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
                          metadata={
                              mlperf_logger.constants.STATUS:
                              mlperf_logger.constants.ABORTED
                          })
Example #9
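# Appears to be a cross-domain person re-identification training script: `cfg` is
# a config object, the loader/trainer helpers (`get_train_loader`, `get_trainer`,
# ...) come from the surrounding project, and `amp` / `parallel` are presumably
# apex modules.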
def train(cfg):
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        torch.distributed.init_process_group(backend="nccl",
                                             world_size=num_gpus)

    # set logger
    log_dir = os.path.join("logs/", cfg.source_dataset, cfg.prefix)
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(format="%(asctime)s %(message)s",
                        filename=log_dir + "/" + "log.txt",
                        filemode="a")

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    logger.addHandler(stream_handler)

    # writer = SummaryWriter(log_dir, purge_step=0)

    if dist.is_initialized() and dist.get_rank() != 0:
        logger = writer = None
    else:
        logger.info(pprint.pformat(cfg))

    # training data loader
    if not cfg.joint_training:  # single domain
        train_loader = get_train_loader(root=os.path.join(
            cfg.source.root, cfg.source.train),
                                        batch_size=cfg.batch_size,
                                        image_size=cfg.image_size,
                                        random_flip=cfg.random_flip,
                                        random_crop=cfg.random_crop,
                                        random_erase=cfg.random_erase,
                                        color_jitter=cfg.color_jitter,
                                        padding=cfg.padding,
                                        num_workers=4)
    else:  # cross domain
        source_root = os.path.join(cfg.source.root, cfg.source.train)
        target_root = os.path.join(cfg.target.root, cfg.target.train)

        train_loader = get_cross_domain_train_loader(
            source_root=source_root,
            target_root=target_root,
            batch_size=cfg.batch_size,
            random_flip=cfg.random_flip,
            random_crop=cfg.random_crop,
            random_erase=cfg.random_erase,
            color_jitter=cfg.color_jitter,
            padding=cfg.padding,
            image_size=cfg.image_size,
            num_workers=8)

    # evaluation data loader
    query_loader = None
    gallery_loader = None
    if cfg.eval_interval > 0:
        query_loader = get_test_loader(root=os.path.join(
            cfg.target.root, cfg.target.query),
                                       batch_size=512,
                                       image_size=cfg.image_size,
                                       num_workers=4)

        gallery_loader = get_test_loader(root=os.path.join(
            cfg.target.root, cfg.target.gallery),
                                         batch_size=512,
                                         image_size=cfg.image_size,
                                         num_workers=4)

    # model
    num_classes = cfg.source.num_id
    num_cam = cfg.source.num_cam + cfg.target.num_cam
    cam_ids = train_loader.dataset.target_dataset.cam_ids if cfg.joint_training else train_loader.dataset.cam_ids
    num_instances = len(
        train_loader.dataset.target_dataset) if cfg.joint_training else None

    model = Model(num_classes=num_classes,
                  drop_last_stride=cfg.drop_last_stride,
                  joint_training=cfg.joint_training,
                  num_instances=num_instances,
                  cam_ids=cam_ids,
                  num_cam=num_cam,
                  neighbor_mode=cfg.neighbor_mode,
                  neighbor_eps=cfg.neighbor_eps,
                  scale=cfg.scale,
                  mix=cfg.mix,
                  alpha=cfg.alpha)

    model.cuda()

    # optimizer
    ft_params = model.backbone.parameters()
    new_params = [
        param for name, param in model.named_parameters()
        if not name.startswith("backbone.")
    ]
    param_groups = [{
        'params': ft_params,
        'lr': cfg.ft_lr
    }, {
        'params': new_params,
        'lr': cfg.new_params_lr
    }]

    optimizer = optim.SGD(param_groups, momentum=0.9, weight_decay=cfg.wd)

    # convert model for mixed precision distributed training

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      enabled=cfg.fp16,
                                      opt_level="O2")
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                  milestones=cfg.lr_step,
                                                  gamma=0.1)

    if dist.is_initialized():
        model = parallel.DistributedDataParallel(model, delay_allreduce=True)

    # engine
    checkpoint_dir = os.path.join("checkpoints", cfg.source_dataset,
                                  cfg.prefix)
    engine = get_trainer(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        logger=logger,
        # writer=writer,
        non_blocking=True,
        log_period=cfg.log_period,
        save_interval=10,
        save_dir=checkpoint_dir,
        prefix=cfg.prefix,
        eval_interval=cfg.eval_interval,
        query_loader=query_loader,
        gallery_loader=gallery_loader)

    # training
    engine.run(train_loader, max_epochs=cfg.num_epoch)

    if dist.is_initialized():
        dist.destroy_process_group()
Example #10
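# Appears to be a conditional per-pixel (CPPN-style) GAN on MNIST trained with
# apex amp and DistributedDataParallel. `args`, `Dict` (presumably addict.Dict),
# `models`, `utils` and `distributed` (torch.distributed) are assumed to be
# imported by the surrounding script.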
def main():

    distributed.init_process_group(backend='nccl')

    with open(args.config) as file:
        config = Dict(json.load(file))

    config.update(vars(args))
    config.update(
        dict(world_size=distributed.get_world_size(),
             global_rank=distributed.get_rank(),
             device_count=torch.cuda.device_count()))
    config = apply(Dict, config)
    print(f'config: {config}')

    torch.manual_seed(0)
    torch.cuda.set_device(config.local_rank)

    generator = models.Generator(linear_params=[
        Dict(in_features=44, out_features=128),
        *[Dict(in_features=128, out_features=128)] * 8,
        Dict(in_features=128, out_features=1)
    ]).cuda()
    discriminator = models.Discriminator(
        conv_params=[
            Dict(in_channels=1,
                 out_channels=32,
                 kernel_size=3,
                 stride=2,
                 bias=False),
            Dict(in_channels=32,
                 out_channels=64,
                 kernel_size=3,
                 stride=2,
                 bias=False)
        ],
        linear_param=Dict(in_features=64, out_features=11)).cuda()

    generator_optimizer = torch.optim.Adam(params=generator.parameters(),
                                           lr=config.generator_lr,
                                           betas=(config.generator_beta1,
                                                  config.generator_beta2))
    discriminator_optimizer = torch.optim.Adam(
        params=discriminator.parameters(),
        lr=config.discriminator_lr,
        betas=(config.discriminator_beta1, config.discriminator_beta2))

    [generator, discriminator
     ], [generator_optimizer, discriminator_optimizer] = amp.initialize(
         models=[generator, discriminator],
         optimizers=[generator_optimizer, discriminator_optimizer],
         opt_level=config.opt_level)

    generator = parallel.DistributedDataParallel(generator,
                                                 delay_allreduce=True)
    discriminator = parallel.DistributedDataParallel(discriminator,
                                                     delay_allreduce=True)

    epoch = 0
    global_step = 0
    if config.checkpoint:
        checkpoint = Dict(
            torch.load(config.checkpoint,
                       map_location=lambda storage, location: storage.cuda(
                           config.local_rank)))
        generator.load_state_dict(checkpoint.generator_state_dict)
        generator_optimizer.load_state_dict(
            checkpoint.generator_optimizer_state_dict)
        discriminator.load_state_dict(checkpoint.discriminator_state_dict)
        discriminator_optimizer.load_state_dict(
            checkpoint.discriminator_optimizer_state_dict)
        epoch = checkpoint.last_epoch + 1
        global_step = checkpoint.last_global_step + 1

    if config.global_rank == 0:
        os.makedirs(config.checkpoint_directory, exist_ok=True)
        os.makedirs(config.event_directory, exist_ok=True)
        summary_writer = SummaryWriter(config.event_directory)

    if config.train:

        dataset = datasets.MNIST(root='mnist',
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose(
                                     [transforms.ToTensor()]))

        distributed_sampler = utils.data.distributed.DistributedSampler(
            dataset)

        data_loader = utils.data.DataLoader(dataset=dataset,
                                            batch_size=config.local_batch_size,
                                            num_workers=config.num_workers,
                                            sampler=distributed_sampler,
                                            pin_memory=True,
                                            drop_last=True)

        for epoch in range(epoch, config.num_epochs):

            discriminator.train()
            distributed_sampler.set_epoch(epoch)

            for step, (real_images, real_labels) in enumerate(data_loader):

                real_images = real_images.cuda()
                real_labels = real_labels.cuda()

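                # The generator seems to operate per pixel: one-hot labels and
                # latents are tiled for every pixel and concatenated with (y, x)
                # coordinates, so each generator output is a single pixel value.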
                labels = nn.functional.embedding(real_labels,
                                                 torch.eye(10, device='cuda'))
                labels = labels.repeat(1, config.image_size**2).reshape(-1, 10)

                latents = torch.randn(config.local_batch_size,
                                      32,
                                      device='cuda')
                latents = latents.repeat(1,
                                         config.image_size**2).reshape(-1, 32)

                y = torch.linspace(-1, 1, config.image_size, device='cuda')
                x = torch.linspace(-1, 1, config.image_size, device='cuda')
                y, x = torch.meshgrid(y, x)
                positions = torch.stack((y.reshape(-1), x.reshape(-1)), dim=1)
                positions = positions.repeat(config.local_batch_size, 1)

                fake_images = generator(
                    torch.cat((labels, latents, positions), dim=1))
                fake_images = fake_images.reshape(-1, 1, config.image_size,
                                                  config.image_size)

                real_logits = discriminator(real_images, real_labels)
                real_adversarial_logits, real_classification_logits = torch.split(
                    real_logits, [1, 10], dim=1)

                fake_logits = discriminator(fake_images.detach(), real_labels)
                fake_adversarial_logits, fake_classification_logits = torch.split(
                    fake_logits, [1, 10], dim=1)

                discriminator_loss = torch.mean(
                    nn.functional.softplus(-real_adversarial_logits))
                discriminator_loss += torch.mean(
                    nn.functional.softplus(fake_adversarial_logits))
                discriminator_loss += nn.functional.cross_entropy(
                    real_classification_logits, real_labels)
                discriminator_loss += nn.functional.cross_entropy(
                    fake_classification_logits, real_labels)

                discriminator_optimizer.zero_grad()
                with amp.scale_loss(
                        discriminator_loss,
                        discriminator_optimizer) as scaled_discriminator_loss:
                    scaled_discriminator_loss.backward()
                discriminator_optimizer.step()

                fake_logits = discriminator(fake_images, real_labels)
                fake_adversarial_logits, fake_classification_logits = torch.split(
                    fake_logits, [1, 10], dim=1)

                generator_loss = torch.mean(
                    nn.functional.softplus(-fake_adversarial_logits))
                generator_loss += nn.functional.cross_entropy(
                    fake_classification_logits, real_labels)

                generator_optimizer.zero_grad()
                with amp.scale_loss(
                        generator_loss,
                        generator_optimizer) as scaled_generator_loss:
                    scaled_generator_loss.backward()
                generator_optimizer.step()

                global_step += 1

                if step % 100 == 0 and config.global_rank == 0:

                    summary_writer.add_images(tag='real_images',
                                              img_tensor=real_images.repeat(
                                                  1, 3, 1, 1),
                                              global_step=global_step)
                    summary_writer.add_images(tag='fake_images',
                                              img_tensor=fake_images.repeat(
                                                  1, 3, 1, 1),
                                              global_step=global_step)
                    summary_writer.add_scalars(
                        main_tag='training',
                        tag_scalar_dict=dict(
                            generator_loss=generator_loss,
                            discriminator_loss=discriminator_loss),
                        global_step=global_step)

                    print(
                        f'[training] epoch: {epoch} step: {step} generator_loss: {generator_loss:.4f} discriminator_loss: {discriminator_loss:.4f}'
                    )

            if config.global_rank == 0:
                torch.save(
                    dict(
                        generator_state_dict=generator.state_dict(),
                        generator_optimizer_state_dict=generator_optimizer.state_dict(),
                        discriminator_state_dict=discriminator.state_dict(),
                        discriminator_optimizer_state_dict=discriminator_optimizer.state_dict(),
                        last_epoch=epoch,
                        last_global_step=global_step),
                    f'{config.checkpoint_directory}/epoch_{epoch}')

    if config.generate:

        with torch.no_grad():

            labels = torch.multinomial(torch.ones(10, device='cuda'),
                                       num_samples=1)
            labels = nn.functional.embedding(labels,
                                             torch.eye(10, device='cuda'))
            labels = labels.repeat(1, config.image_size**2).reshape(-1, 10)

            latents = torch.randn(1, 32, device='cuda')
            latents = latents.repeat(1, config.image_size**2).reshape(-1, 32)

            y = torch.linspace(-1, 1, config.image_size, device='cuda')
            x = torch.linspace(-1, 1, config.image_size, device='cuda')
            y, x = torch.meshgrid(y, x)
            positions = torch.stack((y.reshape(-1), x.reshape(-1)), dim=1)
            positions = positions.repeat(1, 1)

            images = generator(torch.cat((labels, latents, positions), dim=1))
            images = images.reshape(-1, config.image_size, config.image_size)

        for i, image in enumerate(images.cpu().numpy()):
            io.imsave(f"{i}.jpg", image)

    if config.global_rank == 0:
        summary_writer.close()
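The per-pixel conditioning above is easy to misread: labels and latents are repeated once per pixel, while the (y, x) coordinate grid is repeated once per sample, so row i of the concatenated input describes pixel i % (image_size**2) of sample i // (image_size**2). A standalone shape check, a minimal sketch with made-up sizes (batch_size=2, image_size=4), illustrates the layout:

import torch

batch_size, image_size = 2, 4                                   # tiny sizes, for illustration only
labels = torch.eye(10)[torch.tensor([3, 7])]                    # (2, 10) one-hot labels
labels = labels.repeat(1, image_size ** 2).reshape(-1, 10)      # (32, 10): each label repeated per pixel
latents = torch.randn(batch_size, 32)
latents = latents.repeat(1, image_size ** 2).reshape(-1, 32)    # (32, 32): each latent repeated per pixel
y = torch.linspace(-1, 1, image_size)
x = torch.linspace(-1, 1, image_size)
y, x = torch.meshgrid(y, x, indexing='ij')                      # 'ij' matches the older default used above
positions = torch.stack((y.reshape(-1), x.reshape(-1)), dim=1)  # (16, 2) pixel coordinates
positions = positions.repeat(batch_size, 1)                     # (32, 2): grid repeated per sample
inputs = torch.cat((labels, latents, positions), dim=1)
print(inputs.shape)                                             # torch.Size([32, 44]): one row per pixel of every sample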
示例#11
0
        generator.load_state_dict(g_checkpoint['model_state_dict'],
                                  strict=False)
        discriminator.load_state_dict(d_checkpoint['model_state_dict'],
                                      strict=False)
        step = g_checkpoint['step']
        alpha = g_checkpoint['alpha']
        iteration = g_checkpoint['iteration']
        print('pre-trained model is loaded step:%d, iteration:%d' %
              (step, iteration))
    else:
        iteration = 0
        step = 1

    if args.distributed:
        generator = parallel.DistributedDataParallel(generator)
        discriminator = parallel.DistributedDataParallel(discriminator)
        vgg = parallel.DistributedDataParallel(vgg)
        face_align_net = parallel.DistributedDataParallel(
            torch.load('./checkpoints/compressed_model_011000.pth',
                       map_location=lambda storage, loc: storage.cuda(
                           args.local_rank)).to(device))
    else:
        if len(args.gpu_ids) > 1:
            generator = nn.DataParallel(generator, args.gpu_ids)
            discriminator = nn.DataParallel(discriminator, args.gpu_ids)
            vgg = nn.DataParallel(vgg, args.gpu_ids)
            face_align_net = nn.DataParallel(
                torch.load('./checkpoints/compressed_model_011000.pth').to(
                    device), args.gpu_ids)
        else:
示例#12
0
        args.world_size = torch.distributed.get_world_size()
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    else:
        # load the dataset using structure DataLoader (part of torch.utils.data)
        dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    # instantiate Generator(nn.Module) and load in cpu/gpu
    generator = Generator().to(device)
    
    ## DEFINE CHECKPOINT
    # Checkpoints are written during training to capture the model state (its parameters
    # plus bookkeeping such as step, alpha and iteration). Here we only test a pre-trained
    # model, so we load one with torch.load (a minimal save/load sketch follows this example).
    if args.distributed:
        g_checkpoint = torch.load(args.checkpoint_path, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        generator = parallel.convert_syncbn_model(generator)
        generator = parallel.DistributedDataParallel(generator)
    else:
        g_checkpoint = torch.load(args.checkpoint_path)
    
    generator.load_state_dict(g_checkpoint['model_state_dict'], strict=False)
    step = g_checkpoint['step']
    alpha = g_checkpoint['alpha']
    iteration = g_checkpoint['iteration']
    print('pre-trained model is loaded step:%d, alpha:%.2f, iteration:%d' % (step, alpha, iteration))
    MSE_Loss = nn.MSELoss()

    # notify all layers that you are in eval mode instead of training mode
    generator.eval()

    test(dataloader, generator, MSE_Loss, step, alpha)
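As noted in the checkpoint comment above, the dict consumed here is produced by a matching torch.save call during training. A minimal sketch of that round trip, with hypothetical function names and path, covering only the keys this example actually reads:

import torch

def save_checkpoint(generator, step, alpha, iteration, path='checkpoint.pth'):
    # Counterpart of the torch.load above; the real training script may store more state.
    torch.save({
        'model_state_dict': generator.state_dict(),
        'step': step,
        'alpha': alpha,
        'iteration': iteration,
    }, path)

def load_checkpoint(generator, path='checkpoint.pth', device='cpu'):
    ckpt = torch.load(path, map_location=device)
    generator.load_state_dict(ckpt['model_state_dict'], strict=False)
    return ckpt['step'], ckpt['alpha'], ckpt['iteration']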
示例#13
0
def main():
    global n_eval_epoch

    ## dataloader
    dataset_train = ImageNet(datapth, mode='train', cropsize=cropsize)
    sampler_train = torch.utils.data.distributed.DistributedSampler(
        dataset_train, shuffle=True)
    batch_sampler_train = torch.utils.data.sampler.BatchSampler(
        sampler_train, batchsize, drop_last=True
    )
    dl_train = DataLoader(
        dataset_train, batch_sampler=batch_sampler_train, num_workers=num_workers, pin_memory=True
    )
    dataset_eval = ImageNet(datapth, mode='val', cropsize=cropsize)
    sampler_val = torch.utils.data.distributed.DistributedSampler(
        dataset_eval, shuffle=False)
    batch_sampler_val = torch.utils.data.sampler.BatchSampler(
        sampler_val, batchsize * 2, drop_last=False
    )
    dl_eval = DataLoader(
        dataset_eval, batch_sampler=batch_sampler_val,
        num_workers=4, pin_memory=True
    )
    n_iters_per_epoch = len(dataset_train) // n_gpus // batchsize
    n_iters = n_epoches * n_iters_per_epoch


    ## model
    #  model = EfficientNet(model_type, n_classes)
    model = build_model(**model_args)


    ## sync bn
    #  if use_sync_bn: model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    init_model_weights(model)
    model.cuda()
    if use_sync_bn: model = parallel.convert_syncbn_model(model)
    crit = nn.CrossEntropyLoss()
    #  crit = LabelSmoothSoftmaxCEV3(lb_smooth)
    #  crit = SoftmaxCrossEntropyV2()

    ## optimizer
    optim = set_optimizer(model, lr, opt_wd, momentum, nesterov=nesterov)

    ## apex
    model, optim = amp.initialize(model, optim, opt_level=fp16_level)

    ## ema of the model parameters (a minimal sketch of such a helper follows this example)
    ema = EMA(model, ema_alpha)

    ## ddp training
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)
    #  local_rank = dist.get_rank()
    #  model = nn.parallel.DistributedDataParallel(
    #      model, device_ids=[local_rank, ], output_device=local_rank
    #  )

    ## log meters
    time_meter = TimeMeter(n_iters)
    loss_meter = AvgMeter()
    logger = logging.getLogger()


    # for mixup
    label_encoder = OnehotEncoder(n_classes=model_args['n_classes'], lb_smooth=lb_smooth)
    mixuper = MixUper(mixup_alpha, mixup=mixup)

    ## train loop
    for e in range(n_epoches):
        sampler_train.set_epoch(e)
        model.train()
        for idx, (im, lb) in enumerate(dl_train):
            im, lb = im.cuda(), lb.cuda()
            #  lb = label_encoder(lb)
            #  im, lb = mixuper(im, lb)
            optim.zero_grad()
            logits = model(im)
            loss = crit(logits, lb) #+ cal_l2_loss(model, weight_decay)
            #  loss.backward()
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
            optim.step()
            torch.cuda.synchronize()
            ema.update_params()
            time_meter.update()
            loss_meter.update(loss.item())
            if (idx + 1) % 200 == 0:
                t_intv, eta = time_meter.get()
                lr_log = scheduler.get_lr_ratio() * lr
                msg = 'epoch: {}, iter: {}, lr: {:.4f}, loss: {:.4f}, time: {:.2f}, eta: {}'.format(
                    e + 1, idx + 1, lr_log, loss_meter.get()[0], t_intv, eta)
                logger.info(msg)
            scheduler.step()
        torch.cuda.empty_cache()
        if (e + 1) % n_eval_epoch == 0:
            if e > 50: n_eval_epoch = 5
            logger.info('evaluating...')
            acc_1, acc_5, acc_1_ema, acc_5_ema = evaluate(ema, dl_eval)
            msg = 'epoch: {}, naive_acc1: {:.4}, naive_acc5: {:.4}, ema_acc1: {:.4}, ema_acc5: {:.4}'.format(e + 1, acc_1, acc_5, acc_1_ema, acc_5_ema)
            logger.info(msg)
    if dist.is_initialized() and dist.get_rank() == 0:
        torch.save(model.module.state_dict(), './res/model_final.pth')
        torch.save(ema.ema_model.state_dict(), './res/model_final_ema.pth')
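The EMA helper used above is not defined in this snippet. A minimal sketch of a parameter EMA with the same call pattern (SimpleEMA is hypothetical; the real class may also track buffers or synchronize across ranks), assuming ema_alpha is the decay factor:

import copy

import torch

class SimpleEMA:

    def __init__(self, model, alpha=0.999):
        # Keep a frozen copy of the (pre-DDP) model whose parameters decay towards the live ones.
        self.model = model
        self.alpha = alpha
        self.ema_model = copy.deepcopy(model).eval()
        for p in self.ema_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update_params(self):
        # ema_p <- alpha * ema_p + (1 - alpha) * p, called once per optimizer step.
        for ema_p, p in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_p.mul_(self.alpha).add_(p, alpha=1.0 - self.alpha)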
示例#14
0
def main():

    with open(args.config) as file:
        config = Dict(json.load(file))

    distributed.init_process_group(backend='nccl')
    world_size = distributed.get_world_size()
    global_rank = distributed.get_rank()
    device_count = torch.cuda.device_count()
    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    print(
        f'Enabled distributed training. (global_rank: {global_rank}/{world_size}, local_rank: {local_rank}/{device_count})'
    )

    torch.manual_seed(0)
    model = models.resnet50()
    model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
    model = model.cuda()

    config.global_batch_size = config.local_batch_size * world_size
    config.lr = config.base_lr * config.global_batch_size / 256

    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=config.opt_level)
    model = parallel.DistributedDataParallel(model, delay_allreduce=True)

    last_epoch = -1
    if args.checkpoint:
        checkpoint = Dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, location: storage.cuda(local_rank)))
        model.load_state_dict(checkpoint.model_state_dict)
        optimizer.load_state_dict(checkpoint.optimizer_state_dict)
        last_epoch = checkpoint.last_epoch

    criterion = nn.CrossEntropyLoss(reduction='mean').cuda()

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=config.lr_milestones,
                                               gamma=config.lr_gamma,
                                               last_epoch=last_epoch)

    summary_writer = SummaryWriter(config.event_directory)

    if args.training:

        os.makedirs(config.checkpoint_directory, exist_ok=True)
        os.makedirs(config.event_directory, exist_ok=True)

        # NOTE: Sharding for distributed training is handled by the pipeline's reader
        # via num_shards / shard_id below, so each rank reads its own shard.
        # NOTE: Should the random seed be the same on every rank of the same node?
        train_pipeline = TrainPipeline(root=config.train_root,
                                       batch_size=config.local_batch_size,
                                       num_threads=config.num_workers,
                                       device_id=local_rank,
                                       num_shards=world_size,
                                       shard_id=global_rank,
                                       image_size=224)
        train_pipeline.build()

        # NOTE: `epoch_size()` reports the per-reader sample count (roughly len(dataset)),
        # which is why it is divided by world_size below to get the per-shard size.
        train_data_loader = pytorch.DALIClassificationIterator(
            pipelines=train_pipeline,
            size=list(train_pipeline.epoch_size().values())[0] // world_size,
            auto_reset=True,
            stop_at_epoch=True)

        val_pipeline = ValPipeline(root=config.val_root,
                                   batch_size=config.local_batch_size,
                                   num_threads=config.num_workers,
                                   device_id=local_rank,
                                   num_shards=world_size,
                                   shard_id=global_rank,
                                   image_size=224)
        val_pipeline.build()

        val_data_loader = pytorch.DALIClassificationIterator(
            pipelines=val_pipeline,
            size=list(val_pipeline.epoch_size().values())[0] // world_size,
            auto_reset=True,
            stop_at_epoch=True)

        for epoch in range(last_epoch + 1, config.num_epochs):

            model.train()
            scheduler.step()

            for step, data in enumerate(train_data_loader):

                images = data[0]["data"]
                labels = data[0]["label"]

                images = images.cuda()
                labels = labels.cuda()
                labels = labels.squeeze().long()

                logits = model(images)
                loss = criterion(logits, labels)

                optimizer.zero_grad()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                if global_rank == 0:

                    summary_writer.add_scalars(main_tag='training',
                                               tag_scalar_dict=dict(loss=loss))

                    print(
                        f'[training] epoch: {epoch} step: {step} loss: {loss}')

            if global_rank == 0:
                torch.save(
                    dict(model_state_dict=model.state_dict(),
                         optimizer_state_dict=optimizer.state_dict(),
                         last_epoch=epoch),
                    f'{config.checkpoint_directory}/epoch_{epoch}')

            model.eval()

            total_steps = 0
            total_loss = 0
            total_accuracy = 0

            with torch.no_grad():

                for step, data in enumerate(val_data_loader):

                    images = data[0]["data"]
                    labels = data[0]["label"]

                    images = images.cuda()
                    labels = labels.cuda()
                    labels = labels.squeeze().long()

                    logits = model(images)
                    loss = criterion(logits, labels) / world_size
                    distributed.all_reduce(loss)

                    predictions = logits.topk(1)[1].squeeze()
                    accuracy = torch.mean(
                        (predictions == labels).float()) / world_size
                    distributed.all_reduce(accuracy)

                    total_steps += 1
                    total_loss += loss
                    total_accuracy += accuracy

                loss = total_loss / total_steps
                accuracy = total_accuracy / total_steps

            if global_rank == 0:

                summary_writer.add_scalars(main_tag='validation',
                                           tag_scalar_dict=dict(
                                               loss=loss, accuracy=accuracy))

                print(
                    f'[validation] epoch: {epoch} loss: {loss} accuracy: {accuracy}'
                )

    if args.evaluation:

        test_pipeline = ValPipeline(root=config.val_root,
                                    batch_size=config.local_batch_size,
                                    num_threads=config.num_workers,
                                    device_id=local_rank,
                                    num_shards=world_size,
                                    shard_id=global_rank,
                                    image_size=224)
        test_pipeline.build()

        test_data_loader = pytorch.DALIClassificationIterator(
            pipelines=test_pipeline,
            size=list(test_pipeline.epoch_size().values())[0] // world_size,
            auto_reset=True,
            stop_at_epoch=True)

        model.eval()

        total_steps = 0
        total_loss = 0
        total_accuracy = 0

        with torch.no_grad():

            for step, data in enumerate(test_data_loader):

                images = data[0]["data"]
                labels = data[0]["label"]

                images = images.cuda()
                labels = labels.cuda()
                labels = labels.squeeze().long()

                logits = model(images)
                loss = criterion(logits, labels) / world_size
                distributed.all_reduce(loss)

                predictions = logits.topk(1)[1].squeeze()
                accuracy = torch.mean(
                    (predictions == labels).float()) / world_size
                distributed.all_reduce(accuracy)

                total_steps += 1
                total_loss += loss
                total_accuracy += accuracy

            loss = total_loss / total_steps
            accuracy = total_accuracy / total_steps

        if global_rank == 0:

            summary_writer.add_scalars(main_tag='validation',
                                       tag_scalar_dict=dict(loss=loss,
                                                            accuracy=accuracy))

            print(
                f'[evaluation] epoch: {last_epoch} loss: {loss} accuracy: {accuracy}'
            )

    summary_writer.close()
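Both the validation and evaluation loops above average loss and accuracy across ranks by dividing by world_size before distributed.all_reduce, whose default reduction is SUM. A small helper isolating that pattern (distributed_mean is a hypothetical name, shown only to make the reduction explicit):

import torch
import torch.distributed as distributed

def distributed_mean(value, world_size):
    # all_reduce defaults to ReduceOp.SUM, so pre-dividing by world_size
    # turns the cross-rank sum into a cross-rank mean, as in the loops above.
    value = value / world_size
    distributed.all_reduce(value)
    return value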