Exemplo n.º 1
0
def main(cfgs):
    """Build the model/data pipeline from ``cfgs`` and run distributed training.

    Intended to run once per GPU process; reads ``WORLD_SIZE`` from the
    environment and uses the ``env://`` rendezvous for the NCCL process group.

    Args:
        cfgs: configuration mapping. Keys read here: ``logger``,
            ``local_rank``, ``log_dir``, ``pth_dir``, ``network``,
            ``criterion``, ``optimizer``, ``scheduler``, ``dataset``,
            ``transforms``, ``sampler``, ``loader``, ``frequency`` and the
            optional ``auxiliary``.
    """
    # NOTE(review): the logger is initialised through ``Logger`` but used
    # through ``Log`` below; confirm both names refer to the same object.
    Logger.init(**cfgs['logger'])

    local_rank = cfgs['local_rank']
    world_size = int(os.environ['WORLD_SIZE'])
    Log.info('rank: {}, world_size: {}'.format(local_rank, world_size))

    # Bind this process to its GPU *before* anything is moved with .cuda();
    # otherwise every rank allocates on the current default device.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')

    # Seed before the network is built so weight initialisation is reproducible.
    cudnn.benchmark = True
    torch.manual_seed(666)
    torch.cuda.manual_seed(666)

    log_dir = cfgs['log_dir']
    pth_dir = cfgs['pth_dir']
    if local_rank == 0:
        assure_dir(log_dir)
        assure_dir(pth_dir)

    aux_config = cfgs.get('auxiliary', None)
    network = ModuleBuilder(cfgs['network'], aux_config)
    # Swap BatchNorm for SyncBatchNorm *before* creating the optimizer and the
    # DDP wrapper: the conversion creates new BN parameter tensors, which an
    # optimizer or DDP wrapper built earlier would never update/synchronise.
    network = apex.parallel.convert_syncbn_model(network).cuda()
    criterion = build_criterion(cfgs['criterion'], aux_config).cuda()
    optimizer = optim.SGD(network.parameters(), **cfgs['optimizer'])
    scheduler = PolyLRScheduler(optimizer, **cfgs['scheduler'])

    dataset = build_dataset(**cfgs['dataset'], **cfgs['transforms'])
    # NOTE(review): local_rank is passed as the global rank; this is only
    # correct for single-node runs.
    sampler = DistributedSampler4Iter(dataset,
                                      world_size=world_size,
                                      rank=local_rank,
                                      **cfgs['sampler'])
    train_loader = DataLoader(dataset, sampler=sampler, **cfgs['loader'])

    model = DistributedDataParallel(network)

    torch.cuda.empty_cache()
    train(local_rank, world_size, pth_dir, cfgs['frequency'], criterion,
          train_loader, model, optimizer, scheduler)
Exemplo n.º 2
0
def main():
    """Train a segmentation network on Cityscapes with distributed data parallel.

    All settings are read from the module-level ``args`` namespace. Each
    process binds to GPU ``args.local_rank`` and iterates its own shuffled
    DataLoader over the dataset. Rank 0 logs progress and writes periodic
    and final checkpoints to ``args.save_dir``.
    """
    # Timed locally so the final "Training cost" report does not depend on a
    # ``start`` variable being defined somewhere else.
    start = timeit.default_timer()

    # make save dir on every rank: each rank's log file lives in it, and
    # exist_ok makes concurrent creation safe.
    os.makedirs(args.save_dir, exist_ok=True)
    # launch the logger
    Log.init(
        log_level=args.log_level,
        log_file=osp.join(args.save_dir, args.log_file),
        log_format=args.log_format,
        rewrite=args.rewrite,
        stdout_level=args.stdout_level
    )
    # RGB or BGR input (RGB input for ImageNet pretrained models while BGR input for caffe pretrained models)
    if args.rgb:
        IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
        IMG_VARS = np.array((0.229, 0.224, 0.225), dtype=np.float32)
    else:
        IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
        IMG_VARS = np.array((1, 1, 1), dtype=np.float32)

    # set models
    import libs.models as models
    deeplab = models.__dict__[args.arch](num_classes=args.num_classes, data_set=args.data_set)
    if args.restore_from is not None:
        saved_state_dict = torch.load(args.restore_from, map_location=torch.device('cpu'))
        new_params = deeplab.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Skip the pretrained classifier head ('fc.*'); keep everything else.
            if key.split('.')[0] != 'fc':
                new_params[key] = value
        Log.info("load pretrined models")
        # Not every architecture exposes a ``backbone`` attribute.
        backbone = getattr(deeplab, 'backbone', None)
        if backbone is not None:
            backbone.load_state_dict(new_params, strict=False)
        else:
            deeplab.load_state_dict(new_params, strict=False)
    else:
        Log.info("train from stracth")

    args.world_size = 1

    if 'WORLD_SIZE' in os.environ and args.apex:
        args.apex = int(os.environ['WORLD_SIZE']) > 1
        args.world_size = int(os.environ['WORLD_SIZE'])
        print("Total world size: ", int(os.environ['WORLD_SIZE']))

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    h, w = args.input_size, args.input_size
    input_size = (h, w)

    # Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)
    Log.info("Local Rank: {}".format(args.local_rank))
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

    # Replace BatchNorm with SyncBatchNorm *before* building the optimizer and
    # the DDP wrapper: the conversion creates new BN parameter tensors, so an
    # optimizer or DDP wrapper created earlier would not update/synchronise them.
    deeplab = apex.parallel.convert_syncbn_model(deeplab)
    deeplab.cuda()

    # set optimizer
    optimizer = optim.SGD(
        [{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate}],
        lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # models transformation
    model = DistributedDataParallel(deeplab)
    model.train()
    model.float()

    # set loss function
    if args.ohem:
        criterion = CriterionOhemDSN(thresh=args.ohem_thres, min_kept=args.ohem_keep)  # OHEM CrossEntropy
        if "ic" in args.arch:
            criterion = CriterionICNet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
        if "dfa" in args.arch:
            criterion = CriterionDFANet(thresh=args.ohem_thres, min_kept=args.ohem_keep)
    else:
        criterion = CriterionDSN()  # CrossEntropy
    criterion.cuda()

    cudnn.benchmark = True

    if args.world_size == 1:
        print(model)

    # Unlike single-process multi-GPU training, each process here has its own
    # DataLoader sampling from the full dataset, so the per-process budget is
    # num_steps batches of batch_size_per_gpu samples.
    max_iters = args.num_steps * args.batch_size_per_gpu
    # set data loader
    data_set = Cityscapes(args.data_dir, args.data_list, max_iters=max_iters, crop_size=input_size,
                          scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN, vars=IMG_VARS,
                          RGB=args.rgb)

    trainloader = data.DataLoader(
        data_set,
        batch_size=args.batch_size_per_gpu, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    print("trainloader", len(trainloader))

    torch.cuda.empty_cache()

    # start training:
    for i_iter, batch in enumerate(trainloader):
        images, labels = batch
        images = images.cuda()
        labels = labels.long().cuda()
        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, args, i_iter, len(trainloader))
        preds = model(images)

        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        # NOTE(review): averaged over args.gpu_num rather than args.world_size;
        # confirm these always match.
        reduce_loss = all_reduce_tensor(loss,
                                        world_size=args.gpu_num)
        if args.local_rank == 0:
            Log.info('iter = {} of {} completed, lr={}, loss = {}'.format(i_iter,
                                                                      len(trainloader), lr, reduce_loss.data.cpu().numpy()))
            if i_iter % args.save_pred_every == 0 and i_iter > args.save_start:
                print('save models ...')
                torch.save(deeplab.state_dict(), osp.join(args.save_dir, str(args.arch) + str(i_iter) + '.pth'))

    end = timeit.default_timer()

    if args.local_rank == 0:
        Log.info("Training cost: " + str(end - start) + 'seconds')
        Log.info("Save final models")
        torch.save(deeplab.state_dict(), osp.join(args.save_dir, str(args.arch) + '_final' + '.pth'))
Exemplo n.º 3
0
def main():
    """Train the AMB (Appearance Memory Bank) video object segmentation model.

    Command-line arguments come from ``parse_args()``; everything else is read
    from the module-level ``opt`` configuration. Training batches are drawn
    from the VOS segmentation sampler (``loader_train``). ``trainloader`` is
    still built because its datasets hold the ``max_skip`` values that are
    increased every ``opt.epochs_per_increment`` epochs and stored in each
    checkpoint. Validation runs on ``testloader`` every ``opt.epoch_per_test``
    epochs.
    """
    settings_batch_size = 4  # Batch size 80   default 64
    settings_num_workers = 16  # Number of workers for image loading
    settings_normalize_mean = [0.485, 0.456, 0.406]  # Normalize mean (default pytorch ImageNet values)
    settings_normalize_std = [0.229, 0.224, 0.225]  # Normalize std (default pytorch ImageNet values)
    settings_search_area_factor = 4.0  # Image patch size relative to target size
    settings_feature_sz = 24  # Size of feature map
    settings_output_sz = settings_feature_sz * 16  # Size of input image patches 24*16
    settings_segm_use_distance = True

    # Settings for the image sample and proposal generation
    settings_center_jitter_factor = {'train': 0, 'test1': 1.5, 'test2': 1.5}
    settings_scale_jitter_factor = {'train': 0, 'test1': 0.25, 'test2': 0.25}
    ####################################################################################################
    start_epoch = 0
    random.seed(0)

    args = parse_args()
    # Use GPU
    visible_gpus = args.gpu if args.gpu != '' else str(opt.gpu_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus
    use_gpu = torch.cuda.is_available() and (args.gpu != '' or int(opt.gpu_id) >= 0)
    # CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES starting at 0,
    # so DataParallel needs those local indices, not the physical GPU ids.
    # Deriving them from ``visible_gpus`` also avoids int('') when --gpu is empty.
    gpu_ids = list(range(len(visible_gpus.split(','))))

    os.makedirs(opt.checkpoint, exist_ok=True)

    # Data
    print('==> Preparing dataset')

    input_size = opt.input_size

    train_transformer = TrainTransform(size=input_size)
    test_transformer = TestTransform(size=input_size)

    try:
        if isinstance(opt.trainset, list):
            datalist = []
            for dataset, freq, max_skip in zip(opt.trainset, opt.datafreq, opt.max_skip):
                ds = DATA_CONTAINER[dataset](
                    train=True,
                    sampled_frames=opt.sampled_frames,
                    transform=train_transformer,
                    max_skip=max_skip,
                    samples_per_video=opt.samples_per_video
                )
                datalist += [ds] * freq

            trainset = data.ConcatDataset(datalist)

        else:
            max_skip = opt.max_skip[0] if isinstance(opt.max_skip, list) else opt.max_skip
            trainset = DATA_CONTAINER[opt.trainset](
                train=True,
                sampled_frames=opt.sampled_frames,
                transform=train_transformer,
                max_skip=max_skip,
                samples_per_video=opt.samples_per_video
            )
    except KeyError:
        print('[ERROR] invalide dataset name is encountered. The current acceptable datasets are:')
        print(list(DATA_CONTAINER.keys()))
        # Non-zero status so callers/schedulers see the failure (exit() returned 0).
        raise SystemExit(1)

    testset = DATA_CONTAINER[opt.valset](
        train=False,
        transform=test_transformer,
        samples_per_video=1
    )

    trainloader = data.DataLoader(trainset, batch_size=opt.train_batch, shuffle=True, num_workers=opt.workers,
                                  collate_fn=multibatch_collate_fn, drop_last=True)

    testloader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=opt.workers,
                                 collate_fn=multibatch_collate_fn)

    #########################################################################################
    vos_train = Vos(split='train')
    transform_train = torchvision.transforms.Compose([dltransforms.ToTensorAndJitter(0.2),
                                                      torchvision.transforms.Normalize(mean=settings_normalize_mean,
                                                                                       std=settings_normalize_std)])
    data_processing_train = segm_processing.SegmProcessing(search_area_factor=settings_search_area_factor,
                                                           output_sz=settings_output_sz,
                                                           center_jitter_factor=settings_center_jitter_factor,
                                                           scale_jitter_factor=settings_scale_jitter_factor,
                                                           mode='pair',
                                                           transform=transform_train,
                                                           use_distance=settings_segm_use_distance)
    dataset_train = segm_sampler.SegmSampler([vos_train], [1],
                                             samples_per_epoch=1000 * settings_batch_size * 8, max_gap=50,
                                             processing=data_processing_train)
    loader_train = LTRLoader('train', dataset_train, training=True, batch_size=settings_batch_size,
                             num_workers=settings_num_workers,
                             shuffle=True, drop_last=True, stack_dim=1)
    #########################################################################################

    # Model
    print("==> creating model")

    net = AMB(opt.keydim, opt.valdim, 'train', mode=opt.mode, iou_threshold=opt.iou_threshold)
    print('    Total params: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1000000.0))

    net.eval()

    if use_gpu:
        net = net.cuda()

    assert opt.train_batch % len(gpu_ids) == 0
    net = nn.DataParallel(net, device_ids=gpu_ids, dim=0)

    # set training parameters: freeze the backbone ('Encoder') gradients,
    # train every other parameter.
    for name, param in net.named_parameters():
        param.requires_grad = 'Encoder' not in name

    celoss = cross_entropy_loss

    if opt.loss == 'ce':
        criterion = celoss
    elif opt.loss == 'iou':
        criterion = mask_iou_loss
    elif opt.loss == 'both':
        criterion = lambda pred, target, obj: celoss(pred, target, obj) + mask_iou_loss(pred, target, obj)
    else:
        raise TypeError('unknown training loss %s' % opt.loss)

    if opt.solver == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=opt.learning_rate,
                              momentum=opt.momentum[0], weight_decay=opt.weight_decay)
    elif opt.solver == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate,
                               betas=opt.momentum, weight_decay=opt.weight_decay)
    else:
        raise TypeError('unkown solver type %s' % opt.solver)

    # Resume
    minloss = float('inf')

    opt.checkpoint = osp.join(opt.checkpoint, opt.valset)
    os.makedirs(opt.checkpoint, exist_ok=True)

    if opt.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint {}'.format(opt.resume))
        assert os.path.isfile(opt.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(opt.resume)
        minloss = checkpoint['minloss']
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        skips = checkpoint['max_skip']

        # Best effort: the current dataset layout may not match the one the
        # checkpoint was saved with. Report the cause instead of hiding it.
        try:
            if isinstance(skips, list):
                for idx, skip in enumerate(skips):
                    trainloader.dataset.datasets[idx].set_max_skip(skip)
            else:
                trainloader.dataset.set_max_skip(skips)
        except Exception as err:
            print('[Warning] Initializing max skip fail ({})'.format(err))

        logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'), resume=True)
    else:
        if opt.initial:
            print('==> Initialize model with weight file {}'.format(opt.initial))
            weight = torch.load(opt.initial)
            if isinstance(weight, OrderedDict):
                net.module.load_param(weight)
            else:
                net.module.load_param(weight['state_dict'])

        logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'), resume=False)
        start_epoch = 0

    logger.set_items(['Epoch', 'LR', 'Train Loss'])

    # Replay the learning-rate schedule for epochs already completed.
    for epoch in range(start_epoch):
        adjust_learning_rate(optimizer, epoch, opt)

    # Train and val
    for epoch in range(start_epoch, opt.epochs):

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, opt.epochs, opt.learning_rate))
        adjust_learning_rate(optimizer, epoch, opt)

        net.module.phase = 'train'
        train_loss = train(loader_train,  # loader_train trainloader
                           model=net,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           use_cuda=use_gpu,
                           iter_size=opt.iter_size,
                           mode=opt.mode,
                           threshold=opt.iou_threshold)

        if (epoch + 1) % opt.epoch_per_test == 0:
            net.module.phase = 'test'
            test(testloader,
                 model=net.module,
                 criterion=criterion,
                 epoch=epoch,
                 use_cuda=use_gpu)

        # append logger file
        logger.log(epoch + 1, opt.learning_rate, train_loss)

        # adjust max skip
        if (epoch + 1) % opt.epochs_per_increment == 0:
            if isinstance(trainloader.dataset, data.ConcatDataset):
                for dataset in trainloader.dataset.datasets:
                    dataset.increase_max_skip()
            else:
                trainloader.dataset.increase_max_skip()

        # save model
        is_best = train_loss <= minloss
        minloss = min(minloss, train_loss)
        if isinstance(trainloader.dataset, data.ConcatDataset):
            skips = [ds.max_skip for ds in trainloader.dataset.datasets]
        else:
            skips = trainloader.dataset.max_skip

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'loss': train_loss,
            'minloss': minloss,
            'optimizer': optimizer.state_dict(),
            'max_skip': skips,
        }, epoch + 1, is_best, checkpoint=opt.checkpoint, filename=opt.mode)

    logger.close()

    print('minimum loss:')
    print(minloss)
Exemplo n.º 4
0
def main():
    """Train the STAN video object segmentation model.

    Command-line arguments come from ``parse_args()``; everything else is read
    from the module-level ``opt`` configuration. Trains on ``trainloader``,
    validates on ``testloader`` every ``opt.epoch_per_test`` epochs, raises
    the datasets' ``max_skip`` every ``opt.epochs_per_increment`` epochs, and
    writes a checkpoint (including the ``max_skip`` values) after each epoch.
    """
    start_epoch = 0

    args = parse_args()
    # Use GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu if args.gpu != '' else str(
        opt.gpu_id)
    # Explicit grouping: an explicit --gpu means GPU; otherwise opt.gpu_id >= 0.
    use_gpu = torch.cuda.is_available() and (args.gpu != ''
                                             or int(opt.gpu_id) >= 0)

    os.makedirs(opt.checkpoint, exist_ok=True)

    # Data
    print('==> Preparing dataset')

    input_dim = opt.input_size

    train_transformer = TrainTransform(size=input_dim)
    test_transformer = TestTransform(size=input_dim)

    try:
        if isinstance(opt.trainset, list):
            datalist = []
            for dataset, freq, max_skip in zip(opt.trainset, opt.datafreq,
                                               opt.max_skip):
                ds = DATA_CONTAINER[dataset](
                    train=True,
                    sampled_frames=opt.sampled_frames,
                    transform=train_transformer,
                    max_skip=max_skip,
                    samples_per_video=opt.samples_per_video)
                datalist += [ds] * freq

            trainset = data.ConcatDataset(datalist)

        else:
            max_skip = opt.max_skip[0] if isinstance(opt.max_skip,
                                                     list) else opt.max_skip
            trainset = DATA_CONTAINER[opt.trainset](
                train=True,
                sampled_frames=opt.sampled_frames,
                transform=train_transformer,
                max_skip=max_skip,
                samples_per_video=opt.samples_per_video)
    except KeyError:
        print(
            '[ERROR] invalide dataset name is encountered. The current acceptable datasets are:'
        )
        print(list(DATA_CONTAINER.keys()))
        # Non-zero status so callers/schedulers see the failure (exit() returned 0).
        raise SystemExit(1)

    testset = DATA_CONTAINER[opt.valset](train=False,
                                         transform=test_transformer,
                                         samples_per_video=1)

    trainloader = data.DataLoader(trainset,
                                  batch_size=opt.train_batch,
                                  shuffle=True,
                                  num_workers=opt.workers,
                                  collate_fn=multibatch_collate_fn)

    testloader = data.DataLoader(testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=opt.workers,
                                 collate_fn=multibatch_collate_fn)
    # Model
    print("==> creating model")

    net = STAN(opt.keydim, opt.valdim)
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in net.parameters()) / 1000000.0))

    net.eval()
    if use_gpu:
        net = net.cuda()

    # set training parameters
    for p in net.parameters():
        p.requires_grad = True

    celoss = cross_entropy_loss

    if opt.loss == 'ce':
        criterion = celoss
    elif opt.loss == 'iou':
        criterion = mask_iou_loss
    elif opt.loss == 'both':
        criterion = lambda pred, target, obj: celoss(
            pred, target, obj) + mask_iou_loss(pred, target, obj)
    else:
        raise TypeError('unknown training loss %s' % opt.loss)

    if opt.solver == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum[0],
                              weight_decay=opt.weight_decay)
    elif opt.solver == 'adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=opt.learning_rate,
                               betas=opt.momentum,
                               weight_decay=opt.weight_decay)
    else:
        raise TypeError('unkown solver type %s' % opt.solver)

    # Resume
    minloss = float('inf')

    opt.checkpoint = osp.join(opt.checkpoint, opt.valset)
    os.makedirs(opt.checkpoint, exist_ok=True)

    if opt.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint {}'.format(opt.resume))
        assert os.path.isfile(
            opt.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(opt.resume)
        minloss = checkpoint['minloss']
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        skips = checkpoint['max_skip']

        # Best effort: the current dataset layout may not match the one the
        # checkpoint was saved with. Report the cause instead of hiding it.
        try:
            if isinstance(skips, list):
                for idx, skip in enumerate(skips):
                    trainloader.dataset.datasets[idx].set_max_skip(skip)
            else:
                trainloader.dataset.set_max_skip(skips)
        except Exception as err:
            print('[Warning] Initializing max skip fail ({})'.format(err))

        logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'),
                        resume=True)
    else:
        if opt.initial:
            print('==> Initialize model with weight file {}'.format(
                opt.initial))
            weight = torch.load(opt.initial)
            if isinstance(weight, OrderedDict):
                net.load_param(weight)
            else:
                net.load_param(weight['state_dict'])

        logger = Logger(os.path.join(opt.checkpoint, opt.mode + '_log.txt'),
                        resume=False)
        start_epoch = 0

    logger.set_items(['Epoch', 'LR', 'Train Loss'])

    # Replay the learning-rate schedule for epochs already completed.
    for epoch in range(start_epoch):
        adjust_learning_rate(optimizer, epoch, opt)

    # Train and val
    for epoch in range(start_epoch, opt.epochs):

        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, opt.epochs, opt.learning_rate))
        adjust_learning_rate(optimizer, epoch, opt)

        train_loss = train(trainloader,
                           model=net,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           use_cuda=use_gpu,
                           iter_size=opt.iter_size,
                           mode=opt.mode,
                           threshold=opt.iou_threshold)

        if (epoch + 1) % opt.epoch_per_test == 0:
            test(testloader,
                 model=net,
                 criterion=criterion,
                 epoch=epoch,
                 use_cuda=use_gpu,
                 opt=opt)

        # append logger file
        logger.log(epoch + 1, opt.learning_rate, train_loss)

        # adjust max skip
        if (epoch + 1) % opt.epochs_per_increment == 0:
            if isinstance(trainloader.dataset, data.ConcatDataset):
                for dataset in trainloader.dataset.datasets:
                    dataset.increase_max_skip()
            else:
                trainloader.dataset.increase_max_skip()

        # save model
        is_best = train_loss <= minloss
        minloss = min(minloss, train_loss)
        if isinstance(trainloader.dataset, data.ConcatDataset):
            skips = [ds.max_skip for ds in trainloader.dataset.datasets]
        else:
            skips = trainloader.dataset.max_skip

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'loss': train_loss,
                'minloss': minloss,
                'optimizer': optimizer.state_dict(),
                'max_skip': skips,
            },
            epoch + 1,
            is_best,
            checkpoint=opt.checkpoint,
            filename=opt.mode)

    logger.close()

    print('minimum loss:')
    print(minloss)
Exemplo n.º 5
0
def main():
    """Train a segmentation network on a single GPU, logging to TensorBoard.

    All settings are read from the module-level ``args`` namespace.
    ``args.data_set`` selects the dataset ('pascalvoc' or 'cityscapes') and
    ``args.num_steps`` is the number of epochs. Per-iteration and per-epoch
    losses go to TensorBoard under ``<save_dir>/runs``. Rank 0 saves a
    snapshot every 9 epochs and a final checkpoint at the end.

    Raises:
        ValueError: if ``args.data_set`` is not a supported dataset name.
    """
    # Timed locally so the final "Training cost" report does not depend on a
    # ``start`` variable being defined somewhere else.
    start = timeit.default_timer()

    # make save dir (the log file and TensorBoard runs live in it)
    os.makedirs(args.save_dir, exist_ok=True)

    # launch the logger
    Log.init(log_level=args.log_level,
             log_file=osp.join(args.save_dir, args.log_file),
             log_format=args.log_format,
             rewrite=args.rewrite,
             stdout_level=args.stdout_level)
    # RGB or BGR input (RGB input for ImageNet pretrained models while BGR input for caffe pretrained models)
    if args.rgb:
        IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
        IMG_VARS = np.array((0.229, 0.224, 0.225), dtype=np.float32)
    else:
        IMG_MEAN = np.array((104.00698793, 116.66876762, 122.67891434),
                            dtype=np.float32)
        IMG_VARS = np.array((1, 1, 1), dtype=np.float32)

    # set models
    import libs.models as models
    deeplab = models.__dict__[args.arch](num_classes=args.num_classes)
    if args.restore_from is not None:
        print("LOADING FROM PRETRAINED MODEL")
        saved_state_dict = torch.load(args.restore_from,
                                      map_location=torch.device('cpu'))
        new_params = deeplab.state_dict().copy()
        for key, value in saved_state_dict.items():
            # Skip the pretrained classifier head ('fc.*'); keep everything else.
            if key.split('.')[0] != 'fc':
                new_params[key] = value
        Log.info("load pretrained models")
        deeplab.load_state_dict(new_params, strict=False)
    else:
        Log.info("train from scratch")

    args.world_size = 1

    if 'WORLD_SIZE' in os.environ and args.apex:
        args.apex = int(os.environ['WORLD_SIZE']) > 1
        args.world_size = int(os.environ['WORLD_SIZE'])
        print("Total world size: ", int(os.environ['WORLD_SIZE']))

    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    h, w = args.input_size, args.input_size
    input_size = (h, w)

    # Single-process training: no process group, DDP wrapper or SyncBN here.
    # set optimizer
    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, deeplab.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    optimizer.zero_grad()
    deeplab.cuda()
    model = deeplab
    model.train()
    model.float()

    # set loss function
    if args.ohem:
        criterion = CriterionOhemDSN(
            thresh=args.ohem_thres,
            min_kept=args.ohem_keep)  # OHEM CrossEntropy
    else:
        criterion = CriterionDSN()  # CrossEntropy
    criterion.cuda()

    cudnn.benchmark = True

    # set data loader
    from torchvision import transforms
    augs = transforms.Compose([
        transforms.RandomResizedCrop(300),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize([0.4589, 0.4355, 0.4032],
                             [0.2239, 0.2186, 0.2206])
    ])
    if args.data_set == 'pascalvoc':
        # NOTE(review): the training set is built from VOC's 'val' split;
        # confirm this is intended.
        data_set = VOCSegmentation(args.data_dir,
                                   image_set='val',
                                   scale=args.random_scale,
                                   mean=IMG_MEAN,
                                   vars=IMG_VARS,
                                   transforms=augs)
    elif args.data_set == 'cityscapes':
        data_set = Cityscapes(args.data_dir,
                              args.data_list,
                              crop_size=input_size,
                              scale=args.random_scale,
                              mirror=args.random_mirror,
                              mean=IMG_MEAN,
                              vars=IMG_VARS,
                              RGB=args.rgb)
    else:
        # Previously fell through and crashed later with a NameError.
        raise ValueError('unsupported data_set: {!r}'.format(args.data_set))

    trainloader = data.DataLoader(data_set,
                                  batch_size=args.batch_size_per_gpu,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    num_batches = len(trainloader)
    print("trainloader", num_batches)

    torch.cuda.empty_cache()

    # for tensorboard logs
    tb_path = osp.join(args.save_dir, "runs")
    writer = SummaryWriter(tb_path)

    # start training:
    iter_no = 0  # global iteration counter across epochs (TensorBoard x-axis)
    try:
        for epoch in range(args.num_steps):
            print("epoch " + str(epoch + 1))
            total_loss = 0

            for i_iter, batch in enumerate(trainloader):
                if i_iter % 100 == 0:
                    print("iteration " + str(i_iter + 1))
                images, labels = batch
                images = images.cuda()
                labels = labels.long().cuda()

                optimizer.zero_grad()
                # NOTE(review): the schedule is driven by the per-epoch index
                # i_iter, so it restarts every epoch; confirm that is intended.
                adjust_learning_rate(optimizer, args, i_iter, num_batches)
                preds = model(images)

                loss = criterion(preds, labels)
                loss_value = loss.item()  # one device sync per step
                total_loss += loss_value
                writer.add_scalar("Loss_vs_Iteration", loss_value, iter_no)
                iter_no += 1
                loss.backward()
                optimizer.step()

            writer.add_scalar("Loss_vs_Epoch", total_loss / max(num_batches, 1),
                              epoch)

            if args.local_rank == 0 and epoch % 9 == 0:
                print('save models ...')
                # Name the snapshot after the epoch. The old name used the last
                # batch index, which is the same every epoch, so each snapshot
                # overwrote the previous one.
                torch.save(
                    deeplab.state_dict(),
                    osp.join(args.save_dir,
                             str(args.arch) + '_epoch' + str(epoch) + '.pth'))
    finally:
        writer.close()

    end = timeit.default_timer()

    if args.local_rank == 0:
        Log.info("Training cost: " + str(end - start) + 'seconds')
        Log.info("Save final models")
        torch.save(
            deeplab.state_dict(),
            osp.join(
                args.save_dir,
                str(args.arch) + '_' + str(args.num_steps) + 'epoch_' +
                str(args.batch_size_per_gpu) + '.pth'))