Example #1
def main(opts):
    # ===== Setup distributed =====
    distributed.init_process_group(backend='nccl', init_method='env://')
    if opts.device is not None:
        device_id = opts.device
    else:
        device_id = opts.local_rank
    device = torch.device(device_id)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)  # device_id already resolves opts.device vs. local_rank

    # ===== Initialize logging =====
    logdir_full = f"{opts.logdir}/{opts.dataset}/{opts.name}/"
    if rank == 0:
        logger = Logger(logdir_full,
                        rank=rank,
                        debug=opts.debug,
                        summary=opts.visualize)
    else:
        logger = Logger(logdir_full,
                        rank=rank,
                        debug=opts.debug,
                        summary=False)

    logger.print(f"Device: {device}")

    checkpoint_path = f"checkpoints/{opts.dataset}/{opts.name}.pth"
    os.makedirs(f"checkpoints/{opts.dataset}", exist_ok=True)

    # ===== Set up random seed for reproducibility =====
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # ===== Set up dataset =====
    train_dst, val_dst = get_dataset(opts, train=True)

    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   sampler=DistributedSampler(
                                       train_dst,
                                       num_replicas=world_size,
                                       rank=rank),
                                   num_workers=opts.num_workers,
                                   drop_last=True,
                                   pin_memory=True)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 sampler=DistributedSampler(
                                     val_dst,
                                     num_replicas=world_size,
                                     rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, "
                f"Val set: {len(val_dst)}, n_classes {opts.num_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")
    # This is necessary for computing the scheduler decay
    opts.max_iter = opts.epochs * len(train_loader)

    # ===== Set up model and ckpt =====
    model = Trainer(device, logger, opts)
    model.distribute()

    cur_epoch = 0
    if opts.continue_ckpt:
        opts.ckpt = checkpoint_path
    if opts.ckpt is not None:
        assert os.path.isfile(
            opts.ckpt), "Error: ckpt not found. Check that the path is correct."
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        cur_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model_state"])
        logger.info("[!] Model restored from %s" % opts.ckpt)
        del checkpoint
    else:
        logger.info("[!] Train from scratch")

    # ===== Train procedure =====
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    # uncomment if you want qualitative results on val
    # if rank == 0 and opts.sample_num > 0:
    #     sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
    #     logger.info(f"The samples id are {sample_ids}")
    # else:
    #     sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(
        opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225
                                    ])  # de-normalization for original images

    train_metrics = StreamSegMetrics(opts.num_classes)
    val_metrics = StreamSegMetrics(opts.num_classes)
    results = {}

    # sanity check: the random state should match across processes
    logger.print(torch.randint(0, 100, (1, 1)))

    while cur_epoch < opts.epochs and not opts.test:
        # =====  Train  =====
        start = time.time()
        epoch_loss = model.train(cur_epoch=cur_epoch,
                                 train_loader=train_loader,
                                 metrics=train_metrics,
                                 print_int=opts.print_interval)
        train_score = train_metrics.get_results()
        end = time.time()

        len_ep = int(end - start)
        logger.info(
            f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0] + epoch_loss[1]:.4f}, "
            f"Class Loss={epoch_loss[0]:.4f}, Reg Loss={epoch_loss[1]:.4f}\n"
            f"Train_Acc={train_score['Overall Acc']:.4f}, Train_Iou={train_score['Mean IoU']:.4f} "
            f"\n -- time: {len_ep // 60}:{len_ep % 60:02d} -- ")
        logger.info(
            f"I will finish in {len_ep * (opts.epochs - cur_epoch) // 60} minutes"
        )

        logger.add_scalar("E-Loss", epoch_loss[0] + epoch_loss[1], cur_epoch)
        # logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch)
        # logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch)

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            val_loss, _ = model.validate(loader=val_loader,
                                         metrics=val_metrics,
                                         ret_samples_ids=None)
            val_score = val_metrics.get_results()

            logger.print("Done validation")
            logger.info(
                f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss}"
            )

            log_val(logger, val_metrics, val_score, val_loss, cur_epoch)

            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

        # =====  Save Model  =====
        if rank == 0:
            if not opts.debug:
                save_ckpt(checkpoint_path, model, cur_epoch)
                logger.info("[!] Checkpoint saved.")

        cur_epoch += 1

    torch.distributed.barrier()

    # ==== TESTING =====
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_dst = get_dataset(opts, train=False)
    test_loader = data.DataLoader(test_dst,
                                  batch_size=opts.batch_size_test,
                                  sampler=DistributedSampler(
                                      test_dst,
                                      num_replicas=world_size,
                                      rank=rank),
                                  num_workers=opts.num_workers)

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(test_loader),
                                      opts.sample_num,
                                      replace=False)  # sample idxs for visual.
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    val_loss, ret_samples = model.validate(loader=test_loader,
                                           metrics=val_metrics,
                                           ret_samples_ids=sample_ids)
    val_score = val_metrics.get_results()
    conf_matrixes = val_metrics.get_conf_matrixes()
    logger.print("Done test on all")
    logger.info(f"*** End of Test on all, Total Loss={val_loss}")

    logger.info(val_metrics.to_str(val_score))
    log_samples(logger, ret_samples, denorm, label2color, 0)

    logger.add_figure("Test_Confusion_Matrix_Recall",
                      conf_matrixes['Confusion Matrix'])
    logger.add_figure("Test_Confusion_Matrix_Precision",
                      conf_matrixes["Confusion Matrix Pred"])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    results["T-Prec"] = val_score['Class Prec']
    logger.add_results(results)
    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'])
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'])
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'])
    ret = val_score['Mean IoU']

    logger.close()
    return ret
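
The listing above calls a save_ckpt(checkpoint_path, model, cur_epoch) helper that is not shown. A minimal sketch of what it presumably does, mirroring the keys read back on resume ("epoch" and "model_state"); this is an assumption, not the original helper:

import torch

def save_ckpt(path, model, cur_epoch):
    # key names mirror those read on resume: checkpoint["epoch"], checkpoint["model_state"]
    state = {"epoch": cur_epoch, "model_state": model.state_dict()}
    torch.save(state, path)
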
Example #2
def main():
    args = add_learner_params()
    if args.seed != -1:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
    args.root = 'logs/'+args.name+'/'

    if args.neptune:
        import neptune
        project = "arighosh/pretrain_noisy_label"
        neptune.init(project_qualified_name=project,
                     api_token=os.environ["NEPTUNE_API_TOKEN"])
        neptune.create_experiment(
            name=args.name, send_hardware_metrics=False, params=vars(args))
    fmt = {
        'train_time': '.3f',
        'val_time': '.3f',
        'train_epoch': '.1f',
        'lr': '.1e',
    }
    logger = Logger('logs', base=args.root, fmt=fmt)
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.cuda:
        assert device.type == 'cuda', 'no gpu found!'

    with open(args.root+'config.yml', 'w') as outfile:
        yaml.dump(vars(args), outfile, default_flow_style=False)

    # create model
    model = models.REGISTERED_MODELS[args.problem](args, device=device)
    cur_iter = 0
    # Data loading code
    model.prepare_data()

    continue_training = cur_iter < args.iters
    data_time, it_time = 0, 0
    best_acc = 0.
    best_valid_acc, best_acc_with_valid = 0, 0

    while continue_training:
        train_loader, test_loader, valid_loader, meta_loader = model.dataloaders(
            iters=args.iters)
        train_logs = []
        model.train()
        start_time = time.time()
        for _, batch in enumerate(train_loader):
            cur_iter += 1
            batch = [x.to(device) for x in batch]
            data_time += time.time() - start_time
            logs = {}
            if args.problem not in {'finetune'}:
                meta_batch = next(iter(meta_loader))
                meta_batch = [x.to(device) for x in meta_batch]
                logs = model.train_step(batch, meta_batch, cur_iter)
            else:
                logs = model.train_step(batch, cur_iter)

            # save logs for the batch
            train_logs.append({k: utils.tonp(v) for k, v in logs.items()})
            if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters:
                test_start_time = time.time()
                test_logs, valid_logs = [], []
                model.eval()
                with torch.no_grad():
                    for batch in test_loader:
                        batch = [x.to(device) for x in batch]
                        logs = model.test_step(batch)
                        test_logs.append(logs)
                    for batch in valid_loader:
                        batch = [x.to(device) for x in batch]
                        logs = model.test_step(batch)
                        valid_logs.append(logs)
                model.train()
                test_logs = utils.agg_all_metrics(test_logs)
                valid_logs = utils.agg_all_metrics(valid_logs)
                best_acc = max(best_acc, float(test_logs['acc']))
                test_logs['best_acc'] = best_acc
                if float(valid_logs['acc']) > best_valid_acc:
                    best_valid_acc = float(valid_logs['acc'])
                    best_acc_with_valid = float(test_logs['acc'])
                test_logs['best_acc_with_valid'] = best_acc_with_valid

                if args.neptune:
                    for k, v in test_logs.items():
                        neptune.log_metric('test_'+k, float(v))
                    for k, v in valid_logs.items():
                        neptune.log_metric('valid_'+k, float(v))
                    test_it_time = time.time()-test_start_time
                    neptune.log_metric('test_it_time', test_it_time)
                    neptune.log_metric('test_cur_iter', cur_iter)
                logger.add_logs(cur_iter, test_logs, pref='test_')
            it_time += time.time() - start_time
            if (cur_iter % args.log_freq == 0 or cur_iter >= args.iters):
                train_logs = utils.agg_all_metrics(train_logs)
                if args.neptune:
                    for k, v in train_logs.items():
                        neptune.log_metric('train_'+k, float(v))
                    neptune.log_metric('train_it_time', it_time)
                    neptune.log_metric('train_data_time', data_time)
                    neptune.log_metric(
                        'train_lr', model.optimizer.param_groups[0]['lr'])
                    neptune.log_metric('train_cur_iter', cur_iter)
                logger.add_logs(cur_iter, train_logs, pref='train_')
                logger.add_scalar(
                    cur_iter, 'lr', model.optimizer.param_groups[0]['lr'])
                logger.add_scalar(cur_iter, 'data_time', data_time)
                logger.add_scalar(cur_iter, 'it_time', it_time)
                logger.iter_info()
                logger.save()
                data_time, it_time = 0, 0
                train_logs = []
            if cur_iter >= args.iters:
                continue_training = False
                break
            start_time = time.time()
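
Example #2 relies on utils.agg_all_metrics to collapse the per-batch log dicts into scalar metrics such as 'acc'. A rough sketch of that aggregation under the assumption that it simply averages each key over all batches (the real utility may weight by batch size or handle types differently):

import numpy as np

def agg_all_metrics(logs):
    # logs: list of per-batch dicts of scalars/arrays produced by train_step/test_step
    if not logs:
        return {}
    agg = {}
    for k in logs[0]:
        vals = np.concatenate([np.atleast_1d(np.asarray(l[k], dtype=np.float64)) for l in logs])
        agg[k] = vals.mean()
    return agg
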
Example #3
def main_worker(gpu, ngpus, args):
    fmt = {
        'train_time': '.3f',
        'val_time': '.3f',
        'lr': '.1e',
    }
    logger = Logger('logs', base=args.root, fmt=fmt)

    args.gpu = gpu
    torch.cuda.set_device(gpu)
    args.rank = args.node_rank * ngpus + gpu

    device = torch.device('cuda:%d' % args.gpu)

    if args.dist == 'ddp':
        dist.init_process_group(
            backend='nccl',
            init_method='tcp://%s' % args.dist_address,
            world_size=args.world_size,
            rank=args.rank,
        )

        n_gpus_total = dist.get_world_size()
        assert args.batch_size % n_gpus_total == 0
        args.batch_size //= n_gpus_total
        if args.rank == 0:
            print(
                f'===> {n_gpus_total} GPUs total; batch_size={args.batch_size} per GPU'
            )

        print(
            f'===> Proc {dist.get_rank()}/{dist.get_world_size()}@{socket.gethostname()}',
            flush=True)

    # create model
    model = models.REGISTERED_MODELS[args.problem](args, device=device)

    if args.ckpt != '':
        ckpt = torch.load(args.ckpt, map_location=device)
        model.load_state_dict(ckpt['state_dict'])

    # Data loading code
    model.prepare_data()
    train_loader, val_loader = model.dataloaders(iters=args.iters)

    # define optimizer
    cur_iter = 0
    optimizer, scheduler = models.ssl.configure_optimizers(
        args, model, cur_iter - 1)

    # optionally resume from a checkpoint
    if args.ckpt and not args.eval_only:
        optimizer.load_state_dict(ckpt['opt_state_dict'])

    cudnn.benchmark = True

    continue_training = args.iters != 0
    data_time, it_time = 0, 0

    while continue_training:
        train_logs = []
        model.train()

        start_time = time.time()
        for _, batch in enumerate(train_loader):
            cur_iter += 1

            batch = [x.to(device) for x in batch]
            data_time += time.time() - start_time

            logs = {}
            if not args.eval_only:
                # forward pass and compute loss
                logs = model.train_step(batch, cur_iter)
                loss = logs['loss']

                # gradient step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # save logs for the batch
            train_logs.append({k: utils.tonp(v) for k, v in logs.items()})

            if cur_iter % args.save_freq == 0 and args.rank == 0:
                save_checkpoint(args.root, model, optimizer, cur_iter)

            if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters:
                # TODO: aggregate metrics over all processes
                test_logs = []
                model.eval()
                with torch.no_grad():
                    for batch in val_loader:
                        batch = [x.to(device) for x in batch]
                        # forward pass
                        logs = model.test_step(batch)
                        # save logs for the batch
                        test_logs.append(logs)
                model.train()

                test_logs = utils.agg_all_metrics(test_logs)
                logger.add_logs(cur_iter, test_logs, pref='test_')

            it_time += time.time() - start_time

            if (cur_iter % args.log_freq == 0
                    or cur_iter >= args.iters) and args.rank == 0:
                save_checkpoint(args.root, model, optimizer)
                train_logs = utils.agg_all_metrics(train_logs)

                logger.add_logs(cur_iter, train_logs, pref='train_')
                logger.add_scalar(cur_iter, 'lr',
                                  optimizer.param_groups[0]['lr'])
                logger.add_scalar(cur_iter, 'data_time', data_time)
                logger.add_scalar(cur_iter, 'it_time', it_time)
                logger.iter_info()
                logger.save()

                data_time, it_time = 0, 0
                train_logs = []

            if scheduler is not None:
                scheduler.step()

            if cur_iter >= args.iters:
                continue_training = False
                break

            start_time = time.time()

    save_checkpoint(args.root, model, optimizer)

    if args.dist == 'ddp':
        dist.destroy_process_group()
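
Example #3 resumes from a checkpoint with keys 'state_dict' and 'opt_state_dict' and writes checkpoints via save_checkpoint(args.root, model, optimizer[, cur_iter]). A minimal sketch of such a helper with those key names (hypothetical; the file name and layout are assumptions):

import os
import torch

def save_checkpoint(root, model, optimizer, cur_iter=None):
    # optional iteration suffix so periodic snapshots do not overwrite each other
    suffix = f'_{cur_iter}' if cur_iter is not None else ''
    state = {'state_dict': model.state_dict(), 'opt_state_dict': optimizer.state_dict()}
    torch.save(state, os.path.join(root, f'checkpoint{suffix}.pth'))
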
Example #4
def main(opts):
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = opts.local_rank, torch.device(opts.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    task_name = f"{opts.task}-{opts.dataset}"
    logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/"
    if rank == 0:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=opts.visualize, step=opts.step)
    else:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=False)

    logger.print(f"Device: {device}")

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # xxx Set up dataloader
    train_dst, val_dst, test_dst, n_classes = get_dataset(opts)
    # reset the seed; this reverts any change to the random state made above
    random.seed(opts.random_seed)

    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size,
                                   sampler=DistributedSampler(train_dst, num_replicas=world_size, rank=rank),
                                   num_workers=opts.num_workers, drop_last=True)
    val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                 sampler=DistributedSampler(val_dst, num_replicas=world_size, rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, Val set: {len(val_dst)},"
                f" Test set: {len(test_dst)}, n_classes {n_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")

    # xxx Set up model
    logger.info(f"Backbone: {opts.backbone}")

    step_checkpoint = None
    model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
    logger.info(f"[!] Model made with{'out' if opts.no_pretrained else ''} pre-trained")

    if opts.step == 0:  # at step 0 we don't need to instantiate model_old
        model_old = None
    else:  # instantiate model_old
        model_old = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step - 1))

    if opts.fix_bn:
        model.fix_bn()

    logger.debug(model)

    # xxx Set up optimizer
    params = []
    if not opts.freeze:
        params.append({"params": filter(lambda p: p.requires_grad, model.body.parameters()),
                       'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.head.parameters()),
                   'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.cls.parameters()),
                   'weight_decay': opts.weight_decay})

    optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, max_iters=opts.epochs * len(train_loader), power=opts.lr_power)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    else:
        raise NotImplementedError
    logger.debug("Optimizer:\n%s" % optimizer)

    if model_old is not None:
        [model, model_old], optimizer = amp.initialize([model.to(device), model_old.to(device)], optimizer,
                                                       opt_level=opts.opt_level)
        model_old = DistributedDataParallel(model_old)
    else:
        model, optimizer = amp.initialize(model.to(device), optimizer, opt_level=opts.opt_level)

    # Put the model on GPU
    model = DistributedDataParallel(model, delay_allreduce=True)

    # xxx Load old model from old weights if step > 0!
    if opts.step > 0:
        # get model path
        if opts.step_ckpt is not None:
            path = opts.step_ckpt
        else:
            path = f"checkpoints/step/{task_name}_{opts.name}_{opts.step - 1}.pth"

        # generate model from path
        if os.path.exists(path):
            step_checkpoint = torch.load(path, map_location="cpu")
            model.load_state_dict(step_checkpoint['model_state'], strict=False)  # False because of incr. classifiers
            if opts.init_balanced:
                # balanced init: new classifiers get the background weights and bias = bias_bkg - log(N+1)
                model.module.init_new_classifier(device)
            # load the old model parameters from the same state dict
            model_old.load_state_dict(step_checkpoint['model_state'], strict=True)
            logger.info(f"[!] Previous model loaded from {path}")
            # clean memory
            del step_checkpoint['model_state']
        elif opts.debug:
            logger.info(f"[!] WARNING: Unable to find of step {opts.step - 1}! Do you really want to do from scratch?")
        else:
            raise FileNotFoundError(path)
        # put the old model into distributed memory and freeze it
        for par in model_old.parameters():
            par.requires_grad = False
        model_old.eval()

    # xxx Set up Trainer
    trainer_state = None
    # if not the first step, instantiate the trainer state from step_checkpoint
    if opts.step > 0 and step_checkpoint is not None:
        if 'trainer_state' in step_checkpoint:
            trainer_state = step_checkpoint['trainer_state']

    # instantiate the trainer (the model must already hold the previous-step weights)
    trainer = Trainer(model, model_old, device=device, opts=opts, trainer_state=trainer_state,
                      classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))

    # xxx Handle checkpoint for current model (model old will always be as previous step or None)
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint['best_score']
        logger.info("[!] Model restored from %s" % opts.ckpt)
        # if we want to resume training, resume trainer from checkpoint
        if 'trainer_state' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state'])
        del checkpoint
    else:
        if opts.step == 0:
            logger.info("[!] Train from scratch")

    # xxx Train procedure
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # de-normalization for original images

    TRAIN = not opts.test
    val_metrics = StreamSegMetrics(n_classes)
    results = {}

    # sanity check: the random state should match across processes
    logger.print(torch.randint(0, 100, (1, 1)))
    # train/val here
    while cur_epoch < opts.epochs and TRAIN:
        # =====  Train  =====
        model.train()

        epoch_loss = trainer.train(cur_epoch=cur_epoch, optim=optimizer,
                                   train_loader=train_loader, scheduler=scheduler, logger=logger)

        logger.info(f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0]+epoch_loss[1]},"
                    f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]}")

        # =====  Log metrics on Tensorboard =====
        logger.add_scalar("E-Loss", epoch_loss[0]+epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch)

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            model.eval()
            val_loss, val_score, ret_samples = trainer.validate(loader=val_loader, metrics=val_metrics,
                                                                ret_samples_ids=sample_ids, logger=logger)

            logger.print("Done validation")
            logger.info(f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss[0]+val_loss[1]},"
                        f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")

            logger.info(val_metrics.to_str(val_score))

            # =====  Save Best Model  =====
            if rank == 0:  # save best model at the last iteration
                score = val_score['Mean IoU']
                # best model to build incremental steps
                save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                          model, trainer, optimizer, scheduler, cur_epoch, score)
                logger.info("[!] Checkpoint saved.")

            # =====  Log metrics on Tensorboard =====
            # visualize validation score and samples
            logger.add_scalar("V-Loss", val_loss[0]+val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-reg", val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-cls", val_loss[0], cur_epoch)
            logger.add_scalar("Val_Overall_Acc", val_score['Overall Acc'], cur_epoch)
            logger.add_scalar("Val_MeanIoU", val_score['Mean IoU'], cur_epoch)
            logger.add_table("Val_Class_IoU", val_score['Class IoU'], cur_epoch)
            logger.add_table("Val_Acc_IoU", val_score['Class Acc'], cur_epoch)
            # logger.add_figure("Val_Confusion_Matrix", val_score['Confusion Matrix'], cur_epoch)

            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

            for k, (img, target, lbl) in enumerate(ret_samples):
                img = (denorm(img) * 255).astype(np.uint8)
                target = label2color(target).transpose(2, 0, 1).astype(np.uint8)
                lbl = label2color(lbl).transpose(2, 0, 1).astype(np.uint8)

                concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                logger.add_image(f'Sample_{k}', concat_img, cur_epoch)

        cur_epoch += 1

    # =====  Save Best Model at the end of training =====
    if rank == 0 and TRAIN:  # save best model at the last iteration
        # best model to build incremental steps
        save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                  model, trainer, optimizer, scheduler, cur_epoch, best_score)
        logger.info("[!] Checkpoint saved.")

    torch.distributed.barrier()

    # xxx From here starts the test code
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_loader = data.DataLoader(test_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                  sampler=DistributedSampler(test_dst, num_replicas=world_size, rank=rank),
                                  num_workers=opts.num_workers)

    # load best model
    if TRAIN:
        model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
        # Put the model on GPU
        model = DistributedDataParallel(model.cuda(device))
        ckpt = f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth"
        checkpoint = torch.load(ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"])
        logger.info(f"*** Model restored from {ckpt}")
        del checkpoint
        trainer = Trainer(model, None, device=device, opts=opts)

    model.eval()

    val_loss, val_score, _ = trainer.validate(loader=test_loader, metrics=val_metrics, logger=logger)
    logger.print("Done test")
    logger.info(f"*** End of Test, Total Loss={val_loss[0]+val_loss[1]},"
                f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")
    logger.info(val_metrics.to_str(val_score))
    logger.add_table("Test_Class_IoU", val_score['Class IoU'])
    logger.add_table("Test_Class_Acc", val_score['Class Acc'])
    logger.add_figure("Test_Confusion_Matrix", val_score['Confusion Matrix'])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    logger.add_results(results)

    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'], opts.step)
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'], opts.step)
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'], opts.step)

    logger.close()
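
Example #4 uses utils.PolyLR for the 'poly' learning-rate policy. A small sketch of a polynomial-decay scheduler with the usual semantics lr = base_lr * (1 - iter / max_iters) ** power, built on PyTorch's _LRScheduler (the project's own class may differ in details such as a minimum LR):

from torch.optim.lr_scheduler import _LRScheduler

class PolyLR(_LRScheduler):
    def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1):
        self.max_iters = max_iters
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # decay factor goes from 1 to 0 as last_epoch approaches max_iters
        factor = (1 - min(self.last_epoch, self.max_iters) / self.max_iters) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]
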
Example #5
def train(args):
    # Configuration
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    input_height, input_width = args.input_size

    logger = Logger(log_root='logs/', name=args.logger_name)

    for k, v in args.__dict__.items():
        logger.add_text('configuration', "{}: {}".format(k, v))

    # Dataset
    train_loader, val_loader = get_data_loaders(args)
    batchs_in_val = math.ceil(len(val_loader.dataset) / args.validate_batch)
    print("Train set size:", len(train_loader.dataset))
    print("Val set size:", len(val_loader.dataset))

    # Network
    if args.use_noise:
        noise_layers = {
            'crop': '((0.4,0.55),(0.4,0.55))',
            'cropout': '((0.25,0.35),(0.25,0.35))',
            'dropout': '(0.25,0.35)',
            'jpeg': '()',
            'resize': '(0.4,0.6)',
        }  # This is a combined noise used in the paper
    else:
        noise_layers = dict()
    encoder = Encoder(input_height, input_width, args.info_size)
    noiser = Noiser(noise_layers, torch.device('cuda'))
    decoder = Decoder(args.info_size)
    discriminator = Discriminator()
    encoder.cuda()
    noiser.cuda()
    decoder.cuda()
    discriminator.cuda()

    # Optimizers
    optimizer_enc = torch.optim.Adam(encoder.parameters())
    optimizer_dec = torch.optim.Adam(decoder.parameters())
    optimizer_dis = torch.optim.Adam(discriminator.parameters())

    # Training
    dir_save = 'ckpt/{}'.format(logger.log_name)
    os.makedirs(dir_save, exist_ok=True)
    os.makedirs(dir_save + '/images/', exist_ok=True)
    os.makedirs(dir_save + '/models/', exist_ok=True)
    training_losses = defaultdict(AverageLoss)

    info_fix = torch.randint(0, 2, size=(100, args.info_size)).to(device, dtype=torch.float32)
    image_fix = None
    for image, _ in val_loader:
        image_fix = image.cuda()  # 100 images for validation: the first batch
        break
    global_step = 1
    for epoch in range(1, args.epochs + 1):

        # Train one epoch
        for image, _ in train_loader:
            image = image.cuda()
            batch_size = image.shape[0]
            info = torch.randint(0, 2, size=(batch_size, args.info_size)).to(device, dtype=torch.float32)

            encoder.train()
            noiser.train()
            decoder.train()
            discriminator.train()
            # ---------------- Train the discriminator -----------------------------
            optimizer_dis.zero_grad()
            # train on cover
            y_real = torch.ones(batch_size, 1).cuda()
            y_fake = torch.zeros(batch_size, 1).cuda()

            d_on_cover = discriminator(image)
            encoded_image = encoder(image, info)
            d_on_encoded = discriminator(encoded_image.detach())

            if args.relative_loss:
                d_loss_on_cover = F.binary_cross_entropy_with_logits(d_on_cover - torch.mean(d_on_encoded),
                                                                     y_real)
                d_loss_on_encoded = F.binary_cross_entropy_with_logits(d_on_encoded - torch.mean(d_on_cover),
                                                                       y_fake)
                d_loss = d_loss_on_cover + d_loss_on_encoded
            else:
                d_loss_on_cover = F.binary_cross_entropy_with_logits(d_on_cover, y_real)
                d_loss_on_encoded = F.binary_cross_entropy_with_logits(d_on_encoded, y_fake)

                d_loss = d_loss_on_cover + d_loss_on_encoded

            d_loss.backward()
            optimizer_dis.step()

            # --------------Train the generator (encoder-decoder) ---------------------
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()

            d_on_cover = discriminator(image)
            encoded_image = encoder(image, info)
            noised_and_cover = noiser([encoded_image, image])
            noised_image = noised_and_cover[0]
            decoded_info = decoder(noised_image)
            d_on_encoded = discriminator(encoded_image)
            if args.relative_loss:
                g_loss_adv = \
                    (F.binary_cross_entropy_with_logits(d_on_encoded - torch.mean(d_on_cover), y_real) +
                     F.binary_cross_entropy_with_logits(d_on_cover - torch.mean(d_on_encoded), y_fake)) * 0.5
                g_loss_enc = F.mse_loss(encoded_image, image)
                g_loss_dec = F.mse_loss(decoded_info, info)
            else:
                g_loss_adv = F.binary_cross_entropy_with_logits(d_on_encoded, y_real)
                g_loss_enc = F.mse_loss(encoded_image, image)
                g_loss_dec = F.mse_loss(decoded_info, info)
            g_loss = args.adversarial_loss_constant * g_loss_adv + \
                     args.encoder_loss_constant * g_loss_enc + \
                     args.decoder_loss_constant * g_loss_dec

            g_loss.backward()
            optimizer_enc.step()
            optimizer_dec.step()

            decoded_rounded = decoded_info.detach().cpu().numpy().round().clip(0, 1)
            bitwise_avg_err = \
                np.sum(np.abs(decoded_rounded - info.detach().cpu().numpy())) / \
                (batch_size * info.shape[1])

            losses = {
                'g_loss': g_loss.item(),
                'g_loss_enc': g_loss_enc.item(),
                'g_loss_dec': g_loss_dec.item(),
                'bitwise_avg_error': bitwise_avg_err,
                'g_loss_adv': g_loss_adv.item(),
                'd_loss_on_cover': d_loss_on_cover.item(),
                'd_loss_on_encoded': d_loss_on_encoded.item(),
                'd_loss': d_loss.item()
            }
            if logger:
                for name, loss in losses.items():
                    logger.add_scalar(name + '_iter', loss, global_step)
                    training_losses[name].update(loss)
            global_step += 1

        if logger:
            logger.add_scalar('d_loss_epoch', training_losses['d_loss'].avg, epoch)
            logger.add_scalar('g_loss_epoch', training_losses['g_loss'].avg, epoch)

        # Validate each epoch
        info_random = torch.randint(0, 2, size=(100, args.info_size)).to(device, dtype=torch.float32)
        image_random = None
        choice = random.randint(0, batchs_in_val - 2)
        # print(choice)
        for i, (image, _) in enumerate(val_loader):
            if i < choice:
                continue
            if image.shape[0] < 100:
                continue
            image_random = image.cuda()  # grab the first full batch after the random offset
            break

        encoder.eval()
        noiser.eval()
        decoder.eval()
        discriminator.eval()

        with torch.no_grad():  # no gradients needed for validation
            encoded_image_random = encoder(image_random, info_random)
            noised_and_cover_random = noiser([encoded_image_random, image_random])
            noised_image_random = noised_and_cover_random[0]
            decoded_info_random = decoder(noised_image_random)

            encoded_image_fix = encoder(image_fix, info_fix)
            noised_and_cover_fix = noiser([encoded_image_fix, image_fix])
            noised_image_fix = noised_and_cover_fix[0]
            decoded_info_fix = decoder(noised_image_fix)

        decoded_rounded_fix = decoded_info_fix.detach().cpu().numpy().round().clip(0, 1)
        bitwise_avg_err_fix = \
            np.sum(np.abs(decoded_rounded_fix - info_fix.detach().cpu().numpy())) / \
            (100 * info_fix.shape[1])

        decoded_rounded_random = decoded_info_random.detach().cpu().numpy().round().clip(0, 1)
        bitwise_avg_err_random = \
            np.sum(np.abs(decoded_rounded_random - info_random.detach().cpu().numpy())) / \
            (100 * info_random.shape[1])

        stack_image_random = exec_val(image_random, encoded_image_random,
                                      os.path.join(dir_save, 'images', 'random_epoch{:0>3d}.png'.format(epoch)))
        stack_image_fix = exec_val(image_fix, encoded_image_fix,
                                   os.path.join(dir_save, 'images', 'fix_epoch{:0>3d}.png'.format(epoch)))
        if logger:
            logger.add_scalar('fix_err_ratio', bitwise_avg_err_fix, epoch)
            logger.add_scalar('random_err_ratio', bitwise_avg_err_random, epoch)
            logger.add_image('image_rand', stack_image_random, epoch)
            logger.add_image('image_fix', stack_image_fix, epoch)
        torch.save(encoder.state_dict(), '{}/models/encoder-epoch{:0>3d}.pth'.format(dir_save, epoch))
        torch.save(decoder.state_dict(), '{}/models/decoder-epoch{:0>3d}.pth'.format(dir_save, epoch))
        if args.use_noise:
            torch.save(noiser.state_dict(), '{}/models/noiser-epoch{:0>3d}.pth'.format(dir_save, epoch))
        torch.save(discriminator.state_dict(), '{}/models/discriminator-epoch{:0>3d}.pth'.format(dir_save, epoch))
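
The bitwise_avg_err values computed in Example #5 are simply the bit error rate: the decoder output is rounded to {0, 1} and compared bitwise against the original message. A tiny illustration with made-up numbers:

import numpy as np

decoded = np.array([[0.1, 0.8, 0.45], [0.9, 0.2, 0.6]])  # raw decoder outputs
info = np.array([[0., 1., 1.], [1., 0., 1.]])             # ground-truth bits
ber = np.abs(decoded.round().clip(0, 1) - info).mean()    # one wrong bit out of six -> ~0.167
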