Example #1
def main():
    """Create the model and start the training."""

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)

    # dump_input = torch.rand((args.batch_size, 3, input_size[0], input_size[1]))
    # writer.add_graph(deeplab.cuda(), dump_input.cuda(), verbose=False)

    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        # skip the ImageNet classifier head ('fc.*'); keep everything else
        if i_parts[0] != 'fc':
            new_params['.'.join(i_parts)] = saved_state_dict[i]

    deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.cuda()

    criterion = CriterionAll()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()
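    # DataParallelModel / DataParallelCriterion (CE2P-style parallel utils):
    # the model scatters the batch and returns a per-GPU list of outputs,
    # and the criterion computes the loss on each GPU before gathering it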

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainloader = data.DataLoader(LIPDataSet(args.data_dir,
                                             args.dataset,
                                             crop_size=input_size,
                                             transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=2,
                                  pin_memory=True)
    #lip_dataset = LIPDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform)
    #num_samples = len(lip_dataset)

    #valloader = data.DataLoader(lip_dataset, batch_size=args.batch_size * len(gpus),
    #                             shuffle=False, pin_memory=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)

            preds = model(images)

            loss = criterion(preds, [labels, edges])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

            if i_iter % 500 == 0:

                images_inv = inv_preprocess(images, args.save_num_images)
                labels_colors = decode_parsing(labels,
                                               args.save_num_images,
                                               args.num_classes,
                                               is_pred=False)
                edges_colors = decode_parsing(edges,
                                              args.save_num_images,
                                              2,
                                              is_pred=False)

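                # with DataParallel, `preds` comes back as a per-GPU list;
                # each entry is [[parsing_stage1, parsing_stage2], [edge]], so
                # preds[0][-1] is the final parsing map and preds[1][-1] the edge map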
                if isinstance(preds, list):
                    preds = preds[0]
                preds_colors = decode_parsing(preds[0][-1],
                                              args.save_num_images,
                                              args.num_classes,
                                              is_pred=True)
                pred_edges = decode_parsing(preds[1][-1],
                                            args.save_num_images,
                                            2,
                                            is_pred=True)

                img = vutils.make_grid(images_inv,
                                       normalize=False,
                                       scale_each=True)
                lab = vutils.make_grid(labels_colors,
                                       normalize=False,
                                       scale_each=True)
                pred = vutils.make_grid(preds_colors,
                                        normalize=False,
                                        scale_each=True)
                edge = vutils.make_grid(edges_colors,
                                        normalize=False,
                                        scale_each=True)
                pred_edge = vutils.make_grid(pred_edges,
                                             normalize=False,
                                             scale_each=True)

                writer.add_image('Images/', img, i_iter)
                writer.add_image('Labels/', lab, i_iter)
                writer.add_image('Preds/', pred, i_iter)
                writer.add_image('Edges/', edge, i_iter)
                writer.add_image('PredEdges/', pred_edge, i_iter)

            print('iter = {} of {} completed, loss = {}'.format(
                i_iter, total_iters,
                loss.data.cpu().numpy()))

        torch.save(
            model.state_dict(),
            osp.join(args.snapshot_dir, 'LIP_epoch_' + str(epoch) + '.pth'))

        #parsing_preds, scales, centers = valid(model, valloader, input_size,  num_samples, len(gpus))

        #mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size)

        #print(mIoU)
        #writer.add_scalars('mIoU', mIoU, epoch)

    end = timeit.default_timer()
    print(end - start, 'seconds')
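
A note on helpers shared by these examples: every loop calls adjust_learning_rate, which none of the snippets define, and the final timing print reads a module-level `start` set at script load (`start = timeit.default_timer()`). In CE2P/LIP-style parsing codebases the schedule is usually polynomial ("poly") decay; the following is a minimal assumed sketch, not the source's confirmed implementation (the power of 0.9 is a common default):

def lr_poly(base_lr, cur_iter, max_iter, power=0.9):
    """Polynomial decay: base_lr * (1 - cur_iter / max_iter) ** power."""
    return base_lr * ((1 - float(cur_iter) / max_iter) ** power)


def adjust_learning_rate(optimizer, i_iter, total_iters):
    """Apply the poly-decayed rate to every parameter group and return it
    for logging (hypothetical sketch of the undefined helper)."""
    lr = lr_poly(args.learning_rate, i_iter, total_iters)  # args: module-level, as in the examples
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr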
Example #2
def main():
    """Create the model and start the training."""

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]
    best_f1 = 0

    torch.cuda.set_device(args.local_rank)

    try:
        world_size = int(os.environ['WORLD_SIZE'])
        distributed = world_size > 1
    except (KeyError, ValueError):
        # WORLD_SIZE unset or malformed: fall back to single-process training
        distributed = False
        world_size = 1
    if distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://')
    rank = 0 if not distributed else dist.get_rank()

    writer = SummaryWriter(osp.join(args.snapshot_dir,
                                    TIMESTAMP)) if rank == 0 else None

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if args.type == 'Helen':
        train_dataset = HelenDataSet('dataset/Helen_align_with_hair',
                                     args.dataset,
                                     crop_size=input_size,
                                     transform=transform)
        val_dataset = HelenDataSet('dataset/Helen_align_with_hair',
                                   'test',
                                   crop_size=input_size,
                                   transform=transform)
        args.num_classes = 11
    elif args.type == 'LaPa':
        train_dataset = LapaDataset('dataset/LaPa/origin',
                                    args.dataset,
                                    crop_size=input_size,
                                    transform=transform)
        val_dataset = LapaDataset('dataset/LaPa/origin',
                                  'test',
                                  crop_size=input_size,
                                  transform=transform)
        args.num_classes = 11
    elif args.type == 'Celeb':
        train_dataset = CelebAMaskHQDataSet('dataset/CelebAMask-HQ',
                                            args.dataset,
                                            crop_size=input_size,
                                            transform=transform)
        val_dataset = CelebAMaskHQDataSet('dataset/CelebAMask-HQ',
                                          'test',
                                          crop_size=input_size,
                                          transform=transform)
        args.num_classes = 19
    elif args.type == 'LIP':
        train_dataset = LIPDataSet('dataset/LIP',
                                   args.dataset,
                                   crop_size=input_size,
                                   transform=transform)
        val_dataset = LIPDataSet('dataset/LIP',
                                 'val',
                                 crop_size=input_size,
                                 transform=transform)
        args.num_classes = 20
    else:
        raise ValueError('Unknown dataset type: {}'.format(args.type))

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,  # DistributedSampler shuffles when distributed
                                  num_workers=2,
                                  pin_memory=True,
                                  drop_last=True,
                                  sampler=train_sampler)

    num_samples = len(val_dataset)

    valloader = data.DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    if distributed:
        model = AGRNet(args.num_classes)
    else:
        model = AGRNet(args.num_classes, InPlaceABN)

    if args.restore_from is not None:
        model.load_state_dict(
            torch.load(args.restore_from,
                       map_location='cuda:{}'.format(args.local_rank)), True)
    else:
        resnet_params = torch.load(
            os.path.join(args.snapshot_dir, 'resnet101-imagenet.pth'))
        new_params = model.state_dict().copy()
        for i in resnet_params:
            i_parts = i.split('.')
            # skip the ImageNet classifier head ('fc.*')
            if i_parts[0] != 'fc':
                new_params['.'.join(i_parts)] = resnet_params[i]
        model.load_state_dict(new_params)
    model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = SingleGPU(model)

    # CriterionCrossEntropyEdgeParsing_boundary_agrnet_loss for AGRNet, CriterionCrossEntropyEdgeParsing_boundary_eagrnet_loss for EAGRNet
    criterion = CriterionCrossEntropyEdgeParsing_boundary_agrnet_loss(
        loss_weight=[args.l1, args.l2, args.l3, args.l4],
        num_classes=args.num_classes)
    criterion.cuda()

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        if distributed:
            train_sampler.set_epoch(epoch)
        for i_iter, batch in enumerate(trainloader):
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)

            preds = model(images)

            loss = criterion(preds, [labels, edges])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

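            # average the loss across ranks for logging: weight by local batch
            # size, all-reduce the sums, then divide by the global sample count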
            with torch.no_grad():
                loss = loss.detach() * labels.shape[0]
                count = labels.new_tensor([labels.shape[0]], dtype=torch.long)
                if dist.is_initialized():
                    dist.all_reduce(count, dist.ReduceOp.SUM)
                    dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss /= count.item()

            if not dist.is_initialized() or dist.get_rank() == 0:
                if i_iter % 50 == 0:
                    writer.add_scalar('learning_rate', lr, i_iter)
                    writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

                if i_iter % 500 == 0:

                    images_inv = inv_preprocess(images, args.save_num_images)
                    labels_colors = decode_parsing(labels,
                                                   args.save_num_images,
                                                   args.num_classes,
                                                   is_pred=False)
                    edges_colors = decode_parsing(edges,
                                                  args.save_num_images,
                                                  2,
                                                  is_pred=False)

                    if isinstance(preds, list):
                        preds = preds[0]
                    preds_colors = decode_parsing(preds[0],
                                                  args.save_num_images,
                                                  args.num_classes,
                                                  is_pred=True)
                    pred_edges = decode_parsing(preds[1],
                                                args.save_num_images,
                                                2,
                                                is_pred=True)

                    img = vutils.make_grid(images_inv,
                                           normalize=False,
                                           scale_each=True)
                    lab = vutils.make_grid(labels_colors,
                                           normalize=False,
                                           scale_each=True)
                    pred = vutils.make_grid(preds_colors,
                                            normalize=False,
                                            scale_each=True)
                    edge = vutils.make_grid(edges_colors,
                                            normalize=False,
                                            scale_each=True)
                    pred_edge = vutils.make_grid(pred_edges,
                                                 normalize=False,
                                                 scale_each=True)

                    writer.add_image('Images/', img, i_iter)
                    writer.add_image('Labels/', lab, i_iter)
                    writer.add_image('Preds/', pred, i_iter)
                    writer.add_image('Edge/', edge, i_iter)
                    writer.add_image('Pred_edge/', pred_edge, i_iter)

                print('iter = {} of {} completed, loss = {}'.format(
                    i_iter, total_iters,
                    loss.data.cpu().numpy()))
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_path = os.path.join(args.data_dir, TIMESTAMP)
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            parsing_preds, scales, centers = valid(
                model, valloader, input_size, num_samples,
                osp.join(args.snapshot_dir, save_path))
            mIoU, f1 = compute_mean_ioU(parsing_preds,
                                        scales,
                                        centers,
                                        args.num_classes,
                                        val_dataset,
                                        input_size,
                                        'test',
                                        True,
                                        type=args.type)
            if f1['mean'] > best_f1:
                torch.save(model.module.state_dict(),
                           osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                best_f1 = f1['mean']
            print(mIoU)
            print(f1)
            writer.add_scalars('mIoU', mIoU, epoch)
            writer.add_scalars('f1', f1, epoch)

            if epoch % args.test_fre == 0:
                torch.save(
                    model.module.state_dict(),
                    osp.join(args.snapshot_dir, TIMESTAMP,
                             'epoch_' + str(epoch) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
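
Example #2 wraps the non-distributed model in SingleGPU so that `model.module.state_dict()` works identically in both branches; the class itself is not shown. A minimal sketch, assuming it only needs to expose `.module` and move inputs to the GPU (the loop above passes CPU tensors straight to `model(images)`):

import torch.nn as nn

class SingleGPU(nn.Module):
    """Thin stand-in for DistributedDataParallel on a single GPU:
    exposes `.module` and moves inputs to the device (assumed behavior)."""

    def __init__(self, module):
        super(SingleGPU, self).__init__()
        self.module = module

    def forward(self, x):
        return self.module(x.cuda(non_blocking=True))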
Example #3
def main():
    """Create the model and start the training."""
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = [int(i) for i in args.input_size.split(',')]
    input_size = [h, w]
    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = False
    torch.backends.cudnn.deterministic = False  # why use non-deterministic convolutions here?
    torch.backends.cudnn.enabled = True
    NUM_CLASSES = 7  # parsing
    NUM_HEATMAP = 15  # pose
    NUM_PAFS = 28  # pafs
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # load dataset
    num_samples = 0
    trainloader = data.DataLoader(VOCSegmentation(args.data_dir,
                                                  args.dataset,
                                                  crop_size=input_size,
                                                  stride=args.stride,
                                                  transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=2,
                                  pin_memory=True)

    valloader = None
    if args.print_val != 0:
        valdataset = VOCSegmentation(args.data_dir,
                                     'val',
                                     crop_size=input_size,
                                     transform=transform)
        num_samples = len(valdataset)
        valloader = data.DataLoader(
            valdataset,
            batch_size=8 * len(gpus),  # fixed validation batch size per GPU
            shuffle=False,
            pin_memory=True)

    parsingnet = ParsingNet(num_classes=NUM_CLASSES,
                            num_heatmaps=NUM_HEATMAP,
                            num_pafs=NUM_PAFS)
    criterion_parsing = Criterion()
    criterion_parsing = DataParallelCriterion(criterion_parsing)
    criterion_parsing.cuda()

    optimizer_parsing = optim.SGD(parsingnet.parameters(),
                                  lr=args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)

    optimizer_parsing.zero_grad()
    # load pretrained parameters
    print(args.train_continue)
    if not args.train_continue:
        checkpoint = torch.load(RESNET_IMAGENET)
        load_state(parsingnet, checkpoint)
    else:
        checkpoint = torch.load(args.restore_from_parsing)
        if 'current_epoch' in checkpoint:
            current_epoch = checkpoint['current_epoch']
            args.start_epoch = current_epoch

        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        load_state(parsingnet, checkpoint)

    parsingnet = DataParallelModel(parsingnet).cuda()
    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        parsingnet.train()
        for i_iter, batch in enumerate(trainloader):
            i_iter += len(trainloader) * epoch
            lr = adjust_parsing_lr(optimizer_parsing, i_iter, total_iters)

            images, labels, edges, heatmap, pafs, heatmap_mask, pafs_mask, _ = batch
            images = images.cuda()
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)
            heatmap = heatmap.cuda()
            pafs = pafs.cuda()
            heatmap_mask = heatmap_mask.cuda()
            pafs_mask = pafs_mask.cuda()

            preds = parsingnet(images)
            loss_parsing = criterion_parsing(
                preds, [labels, edges, heatmap, pafs, heatmap_mask, pafs_mask],
                writer, i_iter, total_iters)
            optimizer_parsing.zero_grad()
            loss_parsing.backward()
            optimizer_parsing.step()
            if i_iter % 100 == 0:
                writer.add_scalar('parsing_lr', lr, i_iter)
                writer.add_scalar('loss_total', loss_parsing.item(), i_iter)
            if i_iter % 500 == 0:

                if len(gpus) > 1:
                    preds = preds[0]

                images_inv = inv_preprocess(images, args.save_num_images)
                parsing_labels_c = decode_parsing(labels,
                                                  args.save_num_images,
                                                  is_pred=False)
                preds_colors = decode_parsing(preds[0][-1],
                                              args.save_num_images,
                                              is_pred=True)
                edges_colors = decode_parsing(edges,
                                              args.save_num_images,
                                              is_pred=False)
                pred_edges = decode_parsing(preds[1][-1],
                                            args.save_num_images,
                                            is_pred=True)

                img = vutils.make_grid(images_inv,
                                       normalize=False,
                                       scale_each=True)
                parsing_lab = vutils.make_grid(parsing_labels_c,
                                               normalize=False,
                                               scale_each=True)
                pred_v = vutils.make_grid(preds_colors,
                                          normalize=False,
                                          scale_each=True)
                edge = vutils.make_grid(edges_colors,
                                        normalize=False,
                                        scale_each=True)
                pred_edges = vutils.make_grid(pred_edges,
                                              normalize=False,
                                              scale_each=True)

                writer.add_image('Images/', img, i_iter)
                writer.add_image('Parsing_labels/', parsing_lab, i_iter)
                writer.add_image('Parsing_Preds/', pred_v, i_iter)

                writer.add_image('Edges/', edge, i_iter)
                writer.add_image('Edges_preds/', pred_edges, i_iter)

        if (epoch + 1) % 15 == 0:
            if args.print_val != 0:
                parsing_preds, scales, centers = valid(parsingnet, valloader,
                                                       input_size, num_samples,
                                                       gpus)
                mIoU = compute_mean_ioU(parsing_preds, scales, centers,
                                        NUM_CLASSES, args.data_dir, input_size)
                with open(os.path.join(args.snapshot_dir, "val_res.txt"), "a+") as f:
                    f.write(str(epoch) + str(mIoU) + '\n')
            snapshot_name_parsing = osp.join(
                args.snapshot_dir, 'PASCAL_parsing_' + str(epoch) + '.pth')
            torch.save(
                {
                    'state_dict': parsingnet.state_dict(),
                    'optimizer': optimizer_parsing.state_dict(),
                    'current_epoch': epoch
                }, snapshot_name_parsing)

    end = timeit.default_timer()
    print(end - start, 'seconds')
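
Example #3 defers checkpoint surgery to load_state, which is not defined in the snippet. The usual pattern is to copy only keys that exist in the model with matching shapes and load non-strictly; a sketch under that assumption (hypothetical version of the helper):

def load_state(model, checkpoint_state_dict):
    """Copy matching weights from a checkpoint into `model`,
    tolerating missing or renamed keys."""
    own_state = model.state_dict()
    filtered = {}
    for name, param in checkpoint_state_dict.items():
        name = name.replace('module.', '')  # strip a DataParallel prefix if present
        if name in own_state and own_state[name].size() == param.size():
            filtered[name] = param
    # strict=False leaves layers absent from the checkpoint at their init values
    model.load_state_dict(filtered, strict=False)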
Example #4
def main():
    """Create the model and start the training."""
    cycle_n = 0
    start_epoch = args.start_epoch
    writer = SummaryWriter(osp.join(args.snapshot_dir, TIMESTAMP))
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]
    best_f1 = 0

    torch.cuda.set_device(args.local_rank)

    try:
        world_size = int(os.environ['WORLD_SIZE'])
        distributed = world_size > 1
    except (KeyError, ValueError):
        # WORLD_SIZE unset or malformed: fall back to single-process training
        distributed = False
        world_size = 1
    if distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method='env://')
    rank = 0 if not distributed else dist.get_rank()

    log_file = osp.join(args.snapshot_dir, TIMESTAMP + 'output.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')
    logger.info(f'Distributed training: {distributed}')

    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    if distributed:
        model = dml_csr.DML_CSR(args.num_classes)
        schp_model = dml_csr.DML_CSR(args.num_classes)
    else:
        model = dml_csr.DML_CSR(args.num_classes, InPlaceABN)
        schp_model = dml_csr.DML_CSR(args.num_classes, InPlaceABN)

    if args.restore_from is not None:
        print('Resume training from {}'.format(args.restore_from))
        model.load_state_dict(torch.load(args.restore_from), True)
        # assumes checkpoint names end in '_<epoch>.pth' with no other dots in the path
        start_epoch = int(float(
            args.restore_from.split('.')[0].split('_')[-1])) + 1
    else:
        resnet_params = torch.load(RESTORE_FROM)
        new_params = model.state_dict().copy()
        for i in resnet_params:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                new_params['.'.join(i_parts[0:])] = resnet_params[i]
        model.load_state_dict(new_params)
    model.cuda()

    args.schp_restore = osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth')
    if os.path.exists(args.schp_restore):
        print('Resume schp checkpoint from {}'.format(args.schp_restore))
        schp_model.load_state_dict(torch.load(args.schp_restore), True)
    else:
        schp_resnet_params = torch.load(RESTORE_FROM)
        schp_new_params = schp_model.state_dict().copy()
        for i in schp_resnet_params:
            i_parts = i.split('.')
            if not i_parts[0] == 'fc':
                schp_new_params['.'.join(i_parts[0:])] = schp_resnet_params[i]
        schp_model.load_state_dict(schp_new_params)
    schp_model.cuda()

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
        schp_model = torch.nn.parallel.DistributedDataParallel(
            schp_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = SingleGPU(model)
        schp_model = SingleGPU(schp_model)

    criterion = Criterion(loss_weight=[1, 1, 1, 4, 1],
                          lambda_1=args.lambda_s,
                          lambda_2=args.lambda_e,
                          lambda_3=args.lambda_c,
                          num_classes=args.num_classes)
    criterion.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = FaceDataSet(args.data_dir,
                                args.train_dataset,
                                crop_size=input_size,
                                transform=transform)
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=2,
                                  pin_memory=True,
                                  drop_last=True,
                                  sampler=train_sampler)

    val_dataset = datasets[str(args.model_type)](args.data_dir,
                                                 args.valid_dataset,
                                                 crop_size=input_size,
                                                 transform=transform)
    num_samples = len(val_dataset)
    valloader = data.DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                drop_last=False)

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
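    # SGDRScheduler (not shown) presumably implements warmup followed by
    # cosine annealing with warm restarts (SGDR); after start_cyclical it
    # cycles every cyclical_epoch epochs for the SCHP self-correction phase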
    lr_scheduler = SGDRScheduler(optimizer,
                                 total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100,
                                 warmup_epoch=10,
                                 start_cyclical=args.schp_start,
                                 cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        model.train()
        if distributed:
            train_sampler.set_epoch(epoch)
        for i_iter, batch in enumerate(trainloader):
            i_iter += len(trainloader) * epoch

            if epoch < args.schp_start:
                lr = adjust_learning_rate(optimizer, i_iter, total_iters)
            else:
                lr = lr_scheduler.get_lr()[0]

            images, labels, edges, semantic_edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)
            semantic_edges = semantic_edges.long().cuda(non_blocking=True)

            preds = model(images)

            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds, soft_edges, soft_semantic_edges = schp_model(
                        images)
            else:
                soft_preds = None
                soft_edges = None
                soft_semantic_edges = None

            loss = criterion(preds, [
                labels, edges, semantic_edges, soft_preds, soft_edges,
                soft_semantic_edges
            ], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            with torch.no_grad():
                loss = loss.detach() * labels.shape[0]
                count = labels.new_tensor([labels.shape[0]], dtype=torch.long)
                if dist.is_initialized():
                    dist.all_reduce(count, dist.ReduceOp.SUM)
                    dist.all_reduce(loss, dist.ReduceOp.SUM)
                loss /= count.item()

            if not dist.is_initialized() or dist.get_rank() == 0:
                if i_iter % 50 == 0:
                    writer.add_scalar('learning_rate', lr, i_iter)
                    writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

                if i_iter % 500 == 0:
                    images_inv = inv_preprocess(images, args.save_num_images)
                    labels_colors = decode_parsing(labels,
                                                   args.save_num_images,
                                                   args.num_classes,
                                                   is_pred=False)
                    edges_colors = decode_parsing(edges,
                                                  args.save_num_images,
                                                  2,
                                                  is_pred=False)
                    semantic_edges_colors = decode_parsing(
                        semantic_edges,
                        args.save_num_images,
                        args.num_classes,
                        is_pred=False)

                    if isinstance(preds, list):
                        preds = preds[0]
                    preds_colors = decode_parsing(preds[0],
                                                  args.save_num_images,
                                                  args.num_classes,
                                                  is_pred=True)
                    pred_edges = decode_parsing(preds[1],
                                                args.save_num_images,
                                                2,
                                                is_pred=True)
                    pred_semantic_edges_colors = decode_parsing(
                        preds[2],
                        args.save_num_images,
                        args.num_classes,
                        is_pred=True)

                    img = vutils.make_grid(images_inv,
                                           normalize=False,
                                           scale_each=True)
                    lab = vutils.make_grid(labels_colors,
                                           normalize=False,
                                           scale_each=True)
                    pred = vutils.make_grid(preds_colors,
                                            normalize=False,
                                            scale_each=True)
                    edge = vutils.make_grid(edges_colors,
                                            normalize=False,
                                            scale_each=True)
                    pred_edge = vutils.make_grid(pred_edges,
                                                 normalize=False,
                                                 scale_each=True)
                    pred_semantic_edges = vutils.make_grid(
                        pred_semantic_edges_colors,
                        normalize=False,
                        scale_each=True)

                    writer.add_image('Images/', img, i_iter)
                    writer.add_image('Labels/', lab, i_iter)
                    writer.add_image('Preds/', pred, i_iter)
                    writer.add_image('Edge/', edge, i_iter)
                    writer.add_image('Pred_edge/', pred_edge, i_iter)

                cur_loss = loss.data.cpu().numpy()
                logger.info(
                    f'iter = {i_iter} of {total_iters} completed, loss = {cur_loss}, lr = {lr}'
                )

        if (epoch + 1) % args.eval_epochs == 0:
            parsing_preds, scales, centers = valid(model, valloader,
                                                   input_size, num_samples)
            mIoU, f1 = compute_mean_ioU(parsing_preds, scales, centers,
                                        args.num_classes, args.data_dir,
                                        input_size, args.valid_dataset, True)

            if not dist.is_initialized() or dist.get_rank() == 0:
                torch.save(
                    model.module.state_dict(),
                    osp.join(args.snapshot_dir, TIMESTAMP,
                             'checkpoint_{}.pth'.format(epoch + 1)))
                if 'Helen' in args.data_dir:
                    if f1['overall'] > best_f1:
                        torch.save(
                            model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['overall']
                else:
                    if f1['Mean_F1'] > best_f1:
                        torch.save(
                            model.module.state_dict(),
                            osp.join(args.snapshot_dir, TIMESTAMP, 'best.pth'))
                        best_f1 = f1['Mean_F1']

            writer.add_scalars('mIoU', mIoU, epoch)
            writer.add_scalars('f1', f1, epoch)
            logger.info(
                f'mIoU = {mIoU} and f1 = {f1} at epoch = {epoch}; best_f1 so far = {best_f1}'
            )

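            # SCHP self-correction: fold the online model into schp_model by
            # running average, re-estimate its BN statistics on the train set,
            # then validate the averaged model as well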
            if (epoch + 1) >= args.schp_start and (
                    epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
                logger.info(f'Self-correction cycle number {cycle_n}')
                schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
                cycle_n += 1
                schp.bn_re_estimate(trainloader, schp_model)
                parsing_preds, scales, centers = valid(schp_model, valloader,
                                                       input_size, num_samples)
                mIoU, f1 = compute_mean_ioU(parsing_preds, scales, centers,
                                            args.num_classes, args.data_dir,
                                            input_size, args.valid_dataset,
                                            True)

                if not dist.is_initialized() or dist.get_rank() == 0:
                    torch.save(
                        schp_model.module.state_dict(),
                        osp.join(args.snapshot_dir, TIMESTAMP,
                                 'schp_{}_checkpoint.pth'.format(cycle_n)))

                    if 'Helen' in args.data_dir:
                        if f1['overall'] > best_f1:
                            torch.save(
                                schp_model.module.state_dict(),
                                osp.join(args.snapshot_dir, TIMESTAMP,
                                         'best.pth'))
                            best_f1 = f1['overall']
                    else:
                        if f1['Mean_F1'] > best_f1:
                            torch.save(
                                schp_model.module.state_dict(),
                                osp.join(args.snapshot_dir, TIMESTAMP,
                                         'best.pth'))
                            best_f1 = f1['Mean_F1']
                writer.add_scalars('mIoU', mIoU, epoch)
                writer.add_scalars('f1', f1, epoch)
                logger.info(
                    f'mIoU = {mIoU} and f1 = {f1} at epoch = {epoch}; best_f1 so far = {best_f1}'
                )

            torch.cuda.empty_cache()
            end = timeit.default_timer()
            print('epoch = {} of {} completed, avg {:.1f} s per epoch'.format(
                epoch, args.epochs, (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print(end - start, 'seconds')
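
The schp.moving_average / schp.bn_re_estimate pair in Example #4 implements the self-correction cycle of SCHP ("Self-Correction for Human Parsing"): schp_model tracks a running mean of the online model's weights, then has its BatchNorm statistics re-estimated on the training data. A sketch of the averaging step, assuming it updates parameters in place with weight 1 / (cycle_n + 1):

import torch

def moving_average(avg_model, model, alpha):
    """In place: avg <- (1 - alpha) * avg + alpha * online. With
    alpha = 1 / (cycle_n + 1), avg_model stays the mean of all cycle
    checkpoints seen so far (assumed form of schp.moving_average)."""
    with torch.no_grad():
        for p_avg, p in zip(avg_model.parameters(), model.parameters()):
            p_avg.mul_(1.0 - alpha).add_(p, alpha=alpha)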
Example #5
def valid(model, valloader, input_size, num_samples, num_gpus):
    model.eval()

    parsing_preds = np.zeros((num_samples, input_size[0], input_size[1]),
                             dtype=np.uint8)

    scales = np.zeros((num_samples, 2), dtype=np.float32)
    centers = np.zeros((num_samples, 2), dtype=np.int32)

    idx = 0
    interp = torch.nn.Upsample(size=(input_size[0], input_size[1]), mode='bilinear', align_corners=True)
    interp_1 = torch.nn.Upsample(size=(384, 384), mode='bilinear', align_corners=True)
    with torch.no_grad():
        for index, batch in enumerate(valloader):
            image, label_parsing, label_r0, label_r1, label_r2, label_r3, label_l0, label_l1, label_l2, label_l3, label_l4, label_l5, label_edge, meta = batch
            num_images = image.size(0)
            if index % 10 == 0:
                print('%d processed' % (index * num_images))

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            scales[idx:idx + num_images, :] = s[:, :]
            centers[idx:idx + num_images, :] = c[:, :]

            outputs = model(image.cuda())
            if num_gpus > 1:
                for output in outputs:
                    parsing = output[0][-1]
                    nums = len(parsing)
                    parsing = interp(parsing).data.cpu().numpy()
                    parsing = parsing.transpose(0, 2, 3, 1)  # NCHW NHWC
                    parsing_preds[idx:idx + nums, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)

                    idx += nums
            else:
                #gt = torch.from_numpy(parsing_anno)
                gt_parsing_colors = decode_parsing(label_parsing, 2, 20, False)
                gt_r0_colors = decode_parsing(label_r0, 2, 20, False)
                gt_r1_colors = decode_parsing(label_r1, 2, 20, False)
                gt_r2_colors = decode_parsing(label_r2, 2, 20, False)
                gt_r3_colors = decode_parsing(label_r3, 2, 20, False)
                #np.set_printoptions(threshold=np.inf)
                #print(label_l0.numpy())
                gt_l0_colors = decode_parsing(label_l0, 2, 20, False)
                gt_l1_colors = decode_parsing(label_l1, 2, 20, False)
                gt_l2_colors = decode_parsing(label_l2, 2, 20, False)
                gt_l3_colors = decode_parsing(label_l3, 2, 20, False)
                gt_l4_colors = decode_parsing(label_l4, 2, 20, False)
                gt_l5_colors = decode_parsing(label_l5, 2, 20, False)
                for i in range(2):
                    scipy.misc.toimage(gt_parsing_colors[i]).save("./pics/{}_{}_gt.png".format(index, i))
                    scipy.misc.toimage(gt_r0_colors[i]).save("./pics/{}_{}_gt_r0.png".format(index, i))
                    scipy.misc.toimage(gt_r1_colors[i]).save("./pics/{}_{}_gt_r1.png".format(index, i))
                    scipy.misc.toimage(gt_r2_colors[i]).save("./pics/{}_{}_gt_r2.png".format(index, i))
                    scipy.misc.toimage(gt_r3_colors[i]).save("./pics/{}_{}_gt_r3.png".format(index, i))
                    scipy.misc.toimage(gt_l0_colors[i]).save("./pics/{}_{}_gt_l0.png".format(index, i))
                    scipy.misc.toimage(gt_l1_colors[i]).save("./pics/{}_{}_gt_l1.png".format(index, i))
                    scipy.misc.toimage(gt_l2_colors[i]).save("./pics/{}_{}_gt_l2.png".format(index, i))
                    scipy.misc.toimage(gt_l3_colors[i]).save("./pics/{}_{}_gt_l3.png".format(index, i))
                    scipy.misc.toimage(gt_l4_colors[i]).save("./pics/{}_{}_gt_l4.png".format(index, i))
                    scipy.misc.toimage(gt_l5_colors[i]).save("./pics/{}_{}_gt_l5.png".format(index, i))

                parsing = outputs[0][0]
                tmp = interp_1(parsing)
                tmp = torch.argmax(tmp, dim=1, keepdim=False)
                # zero out ignored pixels; move the CPU mask to tmp's device
                ignore_index = (label_parsing == 255).to(tmp.device)
                tmp[ignore_index] = 0
                preds_colors = decode_parsing(tmp, 2, 20, False)
                pred_r0 = outputs[1][0]
                pred_r0 = interp_1(pred_r0)
                pred_r0_colors = decode_parsing(pred_r0, 2, 20, True)
                pred_r1 = outputs[1][1]
                pred_r1 = interp_1(pred_r1)
                pred_r1_colors = decode_parsing(pred_r1, 2, 20, True)
                pred_r2 = outputs[1][2]
                pred_r2 = interp_1(pred_r2)
                pred_r2_colors = decode_parsing(pred_r2, 2, 20, True)
                pred_r3 = outputs[1][3]
                pred_r3 = interp_1(pred_r3)
                pred_r3_colors = decode_parsing(pred_r3, 2, 20, True)
                pred_l0 = outputs[2][0]
                pred_l0 = interp_1(pred_l0)
                pred_l0_colors = decode_parsing(pred_l0, 2, 20, True)
                pred_l1 = outputs[2][1]
                pred_l1 = interp_1(pred_l1)
                pred_l1_colors = decode_parsing(pred_l1, 2, 20, True)
                pred_l2 = outputs[2][2]
                pred_l2 = interp_1(pred_l2)
                pred_l2_colors = decode_parsing(pred_l2, 2, 20, True)
                pred_l3 = outputs[2][3]
                pred_l3 = interp_1(pred_l3)
                pred_l3_colors = decode_parsing(pred_l3, 2, 20, True)
                pred_l4 = outputs[2][4]
                pred_l4 = interp_1(pred_l4)
                pred_l4_colors = decode_parsing(pred_l4, 2, 20, True)
                pred_l5 = outputs[2][5]
                pred_l5 = interp_1(pred_l5)
                pred_l5_colors = decode_parsing(pred_l5, 2, 20, True)
                for i in range(2):
                    scipy.misc.toimage(preds_colors[i]).save("./pics/{}_{}_pred.png".format(index, i))
                    scipy.misc.toimage(pred_r0_colors[i]).save("./pics/{}_{}_pred_r0.png".format(index, i))
                    scipy.misc.toimage(pred_r1_colors[i]).save("./pics/{}_{}_pred_r1.png".format(index, i))
                    scipy.misc.toimage(pred_r2_colors[i]).save("./pics/{}_{}_pred_r2.png".format(index, i))
                    scipy.misc.toimage(pred_r3_colors[i]).save("./pics/{}_{}_pred_r3.png".format(index, i))
                    scipy.misc.toimage(pred_l0_colors[i]).save("./pics/{}_{}_pred_l0.png".format(index, i))
                    scipy.misc.toimage(pred_l1_colors[i]).save("./pics/{}_{}_pred_l1.png".format(index, i))
                    scipy.misc.toimage(pred_l2_colors[i]).save("./pics/{}_{}_pred_l2.png".format(index, i))
                    scipy.misc.toimage(pred_l3_colors[i]).save("./pics/{}_{}_pred_l3.png".format(index, i))
                    scipy.misc.toimage(pred_l4_colors[i]).save("./pics/{}_{}_pred_l4.png".format(index, i))
                    scipy.misc.toimage(pred_l5_colors[i]).save("./pics/{}_{}_pred_l5.png".format(index, i))
                parsing = interp(parsing).data.cpu().numpy()
                parsing = parsing.transpose(0, 2, 3, 1)  # NCHW NHWC
                parsing_preds[idx:idx + num_images, :, :] = np.asarray(np.argmax(parsing, axis=3), dtype=np.uint8)

                idx += num_images

    parsing_preds = parsing_preds[:num_samples, :, :]

    return parsing_preds, scales, centers


def main():
    """Create the model and start the training."""

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    writer = SummaryWriter(args.snapshot_dir)
    gpus = [int(i) for i in args.gpu.split(',')]
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    cudnn.enabled = True
    # cudnn related setting
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)
    print(type(deeplab))

    # dump_input = torch.rand((args.batch_size, 3, input_size[0], input_size[1]))
    # writer.add_graph(deeplab.cuda(), dump_input.cuda(), verbose=False)


    """
    HOW DOES IT LOAD ONLY RESNET101 AND NOT THE RSTE OF THE NET ?
    """
    # UNCOMMENT THE FOLLOWING COMMENTARY TO INITIALYZE THE WEIGHTS
    
    # Load resnet101 weights trained on imagenet and copy it in new_params
    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()

    # CHECK IF WEIGHTS BELONG OR NOT TO THE MODEL
    # belongs = 0
    # doesnt_b = 0
    # for key in saved_state_dict:
    #     if key in new_params:
    #         belongs+=1 
    #         print('key=', key)
    #     else:
    #         doesnt_b+=1
    #         # print('key=', key)
    # print('belongs = ', belongs, 'doesnt_b=', doesnt_b)
    # print('res101 len',len(saved_state_dict))
    # print('new param len',len(new_params))


    for i in saved_state_dict:
        i_parts = i.split('.')
        # exp : i_parts: ['layer2', '3', 'bn2', 'running_mean']

        # The deeplab module names can differ from the args.restore_from
        # names by a 'module.' DataParallel prefix, so normalize the key
        if i_parts[0] == 'module' and i_parts[1] != 'fc':
            key = '.'.join(i_parts[1:])
        elif i_parts[0] != 'fc':
            key = '.'.join(i_parts)
        else:
            continue  # skip the ImageNet classifier head
        # copy only weights whose shapes match the model's
        if key in new_params and new_params[key].size() == saved_state_dict[i].size():
            new_params[key] = saved_state_dict[i]

    deeplab.load_state_dict(new_params)
    
    # UNCOMMENT UNTIL HERE

    model = DataParallelModel(deeplab)
    model.cuda()

    criterion = CriterionAll()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainloader = data.DataLoader(cartoonDataSet(args.data_dir, args.dataset, crop_size=input_size, transform=transform),
                                  batch_size=args.batch_size * len(gpus), shuffle=True, num_workers=8,
                                  pin_memory=True)

    #mIoU for Val set
    val_dataset = cartoonDataSet(args.data_dir, 'val', crop_size=input_size, transform=transform)
    numVal_samples = len(val_dataset)
    
    valloader = data.DataLoader(val_dataset, batch_size=args.batch_size * len(gpus),
                                shuffle=False, pin_memory=True)

    #mIoU for trainTest set
    trainTest_dataset = cartoonDataSet(args.data_dir, 'trainTest', crop_size=input_size, transform=transform)
    numTest_samples = len(trainTest_dataset)
    
    testloader = data.DataLoader(trainTest_dataset, batch_size=args.batch_size * len(gpus),
                                shuffle=False, pin_memory=True)


    optimizer = optim.SGD(
        model.parameters(),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    optimizer.zero_grad()
    # valBatch_idx = 0
    total_iters = args.epochs * len(trainloader)
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)
            images, labels, _, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            preds = model(images)
            # print('preds size in batch', len(preds))
            # print('Size of Segmentation1 tensor output:',preds[0][0].size())
            # print('Segmentation2 tensor output:',preds[0][-1].size())
            # print('Size of Edge tensor output:',preds[1][-1].size())
            loss = criterion(preds, [labels])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

            if i_iter % 500 == 0:
                # print('In iter%500 Size of Segmentation2 GT: ', labels.size())
                # print('In iter%500 Size of edges GT: ', edges.size())
                images_inv = inv_preprocess(images, args.save_num_images)
                # print(labels[0])
                labels_colors = decode_parsing(labels, args.save_num_images, args.num_classes, is_pred=False)
               
                # if isinstance(preds, list):
                #     print(len(preds))
                #     preds = preds[0]
                
                # val_images, _ = valloader[valBatch_idx]
                # valBatch_idx += 1
                # val_sampler = torch.utils.data.RandomSampler(val_dataset,replacement=True, num_samples=args.batch_size * len(gpus))
                # sample_valloader = data.DataLoader(val_dataset, batch_size=args.batch_size * len(gpus),
                #                 shuffle=False, sampler=val_sampler , pin_memory=True)
                # val_images, _ = sample_valloader
                # preds_val = model(val_images)

                # With multiple GPU, preds return a list, therefore we extract the tensor in the list
                if len(gpus) > 1:
                    preds = preds[0]
                    # preds_val = preds_val[0]

                # print('In iter%500 Size of Segmentation2 tensor output:',preds[0][0][-1].size())
                # preds[0][-1] cause model returns [[seg1, seg2], [edge]]
                preds_colors = decode_parsing(preds[0][-1], args.save_num_images, args.num_classes, is_pred=True)
                # preds_val_colors = decode_parsing(preds_val[0][-1], args.save_num_images, args.num_classes, is_pred=True)
                # print("preds type:",type(preds)) #list
                # print("preds shape:", len(preds)) #2
                # hello = preds[0][-1]
                # print("preds type [0][-1]:",type(hello)) #<class 'torch.Tensor'>
                # print("preds len [0][-1]:", len(hello)) #12
                # print("preds len [0][-1]:", hello.shape)#torch.Size([12, 8, 96, 96])
                # print("preds color's type:",type(preds_colors))#torch.tensor
                # print("preds color's shape:",preds_colors.shape) #([2,3,96,96])

                # print('IMAGE', images_inv.size())
                img = vutils.make_grid(images_inv, normalize=False, scale_each=True)
                lab = vutils.make_grid(labels_colors, normalize=False, scale_each=True)
                pred = vutils.make_grid(preds_colors, normalize=False, scale_each=True)
                
                
                # print("preD type:",type(pred)) #<class 'torch.Tensor'>
                # print("preD len:", len(pred))# 3
                # print("preD shape:", pred.shape)#torch.Size([3, 100, 198])

                # 1=head red, 2=body green , 3=left_arm yellow, 4=right_arm blue, 5=left_leg pink
                # 6=right_leg skuBlue, 7=tail grey

                writer.add_image('Images/', img, i_iter)
                writer.add_image('Labels/', lab, i_iter)
                writer.add_image('Preds/', pred, i_iter)
                
               
            print('iter = {} of {} completed, loss = {}'.format(i_iter, total_iters, loss.data.cpu().numpy()))
        
        print('end epoch:', epoch)
        
        if epoch % 99 == 0:
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, 'DFPnet_epoch_' + str(epoch) + '.pth'))
        
        if (epoch % 5 == 0 and epoch < 500) or (epoch % 20 == 0 and epoch >= 500):
            # mIoU for val set
            parsing_preds, scales, centers = valid(model, valloader, input_size, numVal_samples, len(gpus))
            '''
            Insert a sample prediction of a val image on tensorboard
            '''
            # pick a random sample among the predictions
            sample = random.randint(0, len(parsing_preds) - 1)

            # loader resizes the image and converts it to a tensor
            loader = transforms.Compose([
                transforms.Resize(input_size),
                transforms.ToTensor()
            ])

            # get the val segmentation path and open the file
            list_path = os.path.join(args.data_dir, 'val' + '_id.txt')
            val_id = [i_id.strip() for i_id in open(list_path)]
            gt_path = os.path.join(args.data_dir, 'val' + '_segmentations', val_id[sample] + '.png')
            gt = Image.open(gt_path)
            gt = loader(gt)
            # scale gt back from [0, 1] to 0-255
            gt = (gt * 255).int()
            # convert pred from ndarray to PIL image, then to tensor
            display_preds = Image.fromarray(parsing_preds[sample])
            tensor_display_preds = transforms.ToTensor()(display_preds)
            # scale preds back from [0, 1] to 0-255
            tensor_display_preds = (tensor_display_preds * 255).int()
            # color them
            val_preds_colors = decode_parsing(tensor_display_preds, num_images=1, num_classes=args.num_classes, is_pred=False)
            gt_color = decode_parsing(gt, num_images=1, num_classes=args.num_classes, is_pred=False)
            # put in grid
            pred_val = vutils.make_grid(val_preds_colors, normalize=False, scale_each=True)
            gt_val = vutils.make_grid(gt_color, normalize=False, scale_each=True)
            writer.add_image('Preds_val/', pred_val, epoch)
            writer.add_image('Gt_val/', gt_val, epoch)

            mIoUval = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size, 'val')

            print('For val set', mIoUval)
            writer.add_scalars('mIoUval', mIoUval, epoch)

            # mIoU for trainTest set
            parsing_preds, scales, centers = valid(model, testloader, input_size, numTest_samples, len(gpus))

            mIoUtest = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes, args.data_dir, input_size, 'trainTest')

            print('For trainTest set', mIoUtest)
            writer.add_scalars('mIoUtest', mIoUtest, epoch)

    end = timeit.default_timer()
    print(end - start, 'seconds')
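
A portability note on Example #5: scipy.misc.toimage was deprecated in SciPy 1.0 and removed in 1.2, so `valid` above only runs against old SciPy. On current environments the same saves can be done with Pillow; a drop-in sketch (helper name is ours):

import numpy as np
from PIL import Image

def save_color_tensor(tensor, path):
    """Save a CxHxW (or HxW) uint8 image tensor as a PNG, replacing
    scipy.misc.toimage(...).save(...)."""
    arr = tensor.cpu().numpy() if hasattr(tensor, 'cpu') else np.asarray(tensor)
    if arr.ndim == 3:  # CHW -> HWC for PIL
        arr = arr.transpose(1, 2, 0)
    Image.fromarray(arr.astype(np.uint8)).save(path)

# usage: save_color_tensor(gt_parsing_colors[i], "./pics/{}_{}_gt.png".format(index, i))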