Example #1
0
def main():
    """Create the model and start the training.

    All configuration comes from the module-level ``args`` namespace (plus the
    module-level ``config`` passed to ``get_cls_net``). The function:

    1. creates ``args.snapshot_dir`` and a TensorBoard writer logging there;
    2. builds the network and restores weights from ``args.restore_from`` --
       a full training checkpoint (keys ``'state_dict'`` and ``'optimizer'``)
       when ``args.start_epoch > 0``, otherwise a plain state dict whose keys
       are copied into the matching network entries;
    3. trains with SGD over epochs ``[args.start_epoch, args.epochs)``,
       checkpointing every ``args.save_step`` epochs and once more after the
       last trained epoch.

    Raises:
        ValueError: if ``args.data_dir`` names neither the vehicle parsing
            dataset nor LIP.
    """
    print(args)
    os.makedirs(args.snapshot_dir, exist_ok=True)

    writer = SummaryWriter(args.snapshot_dir)
    # 'None' means "do not restrict CUDA_VISIBLE_DEVICES". It must be checked
    # before parsing the list, otherwise int('None') raises ValueError.
    if args.gpu == 'None':
        gpus = list(range(max(torch.cuda.device_count(), 1)))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = get_cls_net(config=config,
                          num_classes=args.num_classes,
                          is_train=True)

    print('-------Load Weight', args.restore_from)
    saved_state_dict = torch.load(args.restore_from)

    if args.start_epoch > 0:
        # Resuming: the checkpoint's 'state_dict' is loaded into the
        # DataParallel wrapper, so its keys must match the wrapped model.
        model = DataParallelModel(deeplab)
        model.load_state_dict(saved_state_dict['state_dict'])
    else:
        # Fresh start: copy every pretrained tensor whose name exists in the
        # network and report the ones that do not.
        new_params = deeplab.state_dict().copy()
        for state_name, tensor in saved_state_dict.items():
            if state_name in new_params:
                new_params[state_name] = tensor
            else:
                print('NOT LOAD', state_name)
        deeplab.load_state_dict(new_params)
        model = DataParallelModel(deeplab)
    print('-------Load Weight Finish', args.restore_from)

    model.cuda()

    criterion = DataParallelCriterion(CriterionAll())
    criterion.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print("-------Loading data...")
    if 'vehicle_parsing_dataset' in args.data_dir:
        dataset_cls = VPDataSet
    elif 'LIP' in args.data_dir:
        dataset_cls = LIPDataSet
    else:
        raise ValueError(
            f'Cannot infer dataset type from data_dir {args.data_dir!r}; '
            "expected it to contain 'vehicle_parsing_dataset' or 'LIP'.")
    parsing_dataset = dataset_cls(args.data_dir,
                                  args.dataset,
                                  crop_size=input_size,
                                  transform=transform)
    print("Data dir : ", args.data_dir)
    print("Dataset : ", args.dataset)
    trainloader = data.DataLoader(parsing_dataset,
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.start_epoch > 0:
        optimizer.load_state_dict(saved_state_dict['optimizer'])
        print('-------Load Optimizer', args.restore_from)

    print("-------Start training...")
    iters_per_epoch = len(trainloader)
    total_iters = args.epochs * iters_per_epoch
    epoch = None  # stays None when there is nothing left to train
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration index, used by the LR schedule and the logs.
            i_iter += iters_per_epoch * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            preds = model(images)

            loss = criterion(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # A single device->host sync, shared by TensorBoard and stdout.
            loss_value = loss.item()
            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss_value, i_iter)

            print(
                f'epoch = {epoch}, iter = {i_iter}/{total_iters}, lr={lr:.6f}, loss = {loss_value:.6f}'
            )

        # Periodic snapshot. The last epoch is always saved after the loop
        # (inside range() `epoch` never reaches args.epochs).
        if (epoch + 1) % args.save_step == 0:
            time.sleep(10)
            print("-------Saving checkpoint...")
            save_checkpoint(model, epoch, optimizer)

    if epoch is not None:
        time.sleep(10)
        save_checkpoint(model, epoch, optimizer)
    else:
        print('-------Nothing to train: start_epoch >= epochs')
    writer.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #2
0
def main():
    """Create the model and start the training.

    All configuration comes from the module-level ``args`` namespace (plus the
    module-level ``config`` passed to ``get_cls_net``). The network is wrapped
    in ``DataParallelModel`` before weights are restored from
    ``args.restore_from``:

    * ``args.start_epoch > 0``: the checkpoint's ``'state_dict'`` and
      ``'optimizer'`` are loaded to resume training;
    * otherwise: entries of the checkpoint's ``'state_dict'`` whose names exist
      in the wrapped model are copied in, and the rest are reported.

    Parameters of the modules in ``deeplab.path_list`` form one SGD group
    each, using ``args.learning_rate``. All remaining parameters form a group
    created with ``lr=1e-6``. ``adjust_learning_rate_parsing`` is called every
    iteration; whether it keeps those per-group rates depends on that helper
    (not visible here).

    A checkpoint is saved after every even epoch and after the last one.
    """
    print(args)
    os.makedirs(args.snapshot_dir, exist_ok=True)

    writer = SummaryWriter(args.snapshot_dir)
    # 'None' means "do not restrict CUDA_VISIBLE_DEVICES". It must be checked
    # before parsing the list, otherwise int('None') raises ValueError.
    if args.gpu == 'None':
        gpus = list(range(max(torch.cuda.device_count(), 1)))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = get_cls_net(config=config,
                          num_classes=args.num_classes,
                          is_train=True)
    model = DataParallelModel(deeplab)

    saved_state_dict = torch.load(args.restore_from)

    if args.start_epoch > 0:
        model.load_state_dict(saved_state_dict['state_dict'])
    else:
        # Keys are compared against the wrapped model, so the pretrained
        # checkpoint is expected to use the same (wrapped) naming.
        new_params = model.state_dict().copy()
        for state_name, tensor in saved_state_dict['state_dict'].items():
            if state_name in new_params:
                new_params[state_name] = tensor
            else:
                print('NOT LOAD', state_name)
        model.load_state_dict(new_params)

    print('-------Load Weight', args.restore_from)

    model.cuda()

    criterion = DataParallelCriterion(CriterionAll2())
    criterion.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainloader = data.DataLoader(LIPDataSet(args.data_dir,
                                             args.dataset,
                                             crop_size=input_size,
                                             transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)

    # Split parameters: each module in deeplab.path_list gets its own group
    # at the default LR; every other parameter goes into one low-LR group.
    # A set keeps the membership test O(1) per parameter.
    path_param_ids = {id(p)
                      for part in deeplab.path_list
                      for p in part.parameters()}
    base_params = [p for p in deeplab.parameters()
                   if id(p) not in path_param_ids]
    params_list = [{'params': base_params, 'lr': 1e-6}]
    params_list.extend({'params': part.parameters()}
                       for part in deeplab.path_list)
    print('len(params_list)', len(params_list))
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.start_epoch > 0:
        optimizer.load_state_dict(saved_state_dict['optimizer'])
        print('========Load Optimizer', args.restore_from)

    iters_per_epoch = len(trainloader)
    total_iters = args.epochs * iters_per_epoch
    epoch = None  # stays None when there is nothing left to train
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration index, used for logging.
            i_iter += iters_per_epoch * epoch
            lr = adjust_learning_rate_parsing(optimizer, epoch)

            images, labels, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            preds = model(images)

            loss = criterion(preds, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # A single device->host sync, shared by TensorBoard and stdout.
            loss_value = loss.item()
            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss_value, i_iter)

            print('epoch = {}, iter = {} of {} completed,lr={}, loss = {}'.
                  format(epoch, i_iter, total_iters, lr, loss_value))

        # Periodic snapshot. The last epoch is always saved after the loop
        # (inside range() `epoch` never reaches args.epochs).
        if epoch % 2 == 0:
            time.sleep(10)
            save_checkpoint(model, epoch, optimizer)

    if epoch is not None:
        time.sleep(10)
        save_checkpoint(model, epoch, optimizer)
    else:
        print('-------Nothing to train: start_epoch >= epochs')
    writer.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #3
0
def main():
    """Create the model and start the training.

    All configuration comes from the module-level ``args`` namespace. A
    ``Res_Deeplab`` network is trained on LIP with a joint parsing + edge loss
    (``CriterionAll`` applied to ``[labels, edges]``).

    Weights come from ``args.restore_from``:

    * ``args.start_epoch > 0``: the checkpoint's ``'state_dict'`` (for the
      DataParallel wrapper) and ``'optimizer'`` are loaded to resume;
    * otherwise: every entry except those under the top-level ``fc`` module
      is copied into the network.

    A checkpoint is saved after every fifth epoch and after the last one.
    """
    print(args)
    os.makedirs(args.snapshot_dir, exist_ok=True)

    writer = SummaryWriter(args.snapshot_dir)
    # 'None' means "do not restrict CUDA_VISIBLE_DEVICES". It must be checked
    # before parsing the list, otherwise int('None') raises ValueError.
    if args.gpu == 'None':
        gpus = list(range(max(torch.cuda.device_count(), 1)))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)

    if args.start_epoch > 0:
        model = DataParallelModel(deeplab)
        model.load_state_dict(saved_state_dict['state_dict'])
    else:
        new_params = deeplab.state_dict().copy()
        for name, tensor in saved_state_dict.items():
            # Skip the pretrained classifier ('fc.*'); this network has its
            # own head sized for args.num_classes.
            if name.split('.')[0] != 'fc':
                new_params[name] = tensor
        deeplab.load_state_dict(new_params)
        model = DataParallelModel(deeplab)
    print('-------Load Weight', args.restore_from)
    model.cuda()

    criterion = DataParallelCriterion(CriterionAll())
    criterion.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainloader = data.DataLoader(LIPDataSet(args.data_dir, args.dataset,
                                             crop_size=input_size,
                                             transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    if args.start_epoch > 0:
        optimizer.load_state_dict(saved_state_dict['optimizer'])
        print('========Load Optimizer', args.restore_from)

    iters_per_epoch = len(trainloader)
    total_iters = args.epochs * iters_per_epoch
    epoch = None  # stays None when there is nothing left to train
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration index, used by the LR schedule and the logs.
            i_iter += iters_per_epoch * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, edges, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            edges = edges.long().cuda(non_blocking=True)

            preds = model(images)

            loss = criterion(preds, [labels, edges])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # A single device->host sync, shared by TensorBoard and stdout.
            loss_value = loss.item()
            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss_value, i_iter)

            print('epoch = {}, iter = {} of {} completed,lr={}, loss = {}'.format(
                epoch, i_iter, total_iters, lr, loss_value))

        # Periodic snapshot. The last epoch is always saved after the loop
        # (inside range() `epoch` never reaches args.epochs).
        if epoch % 5 == 0:
            time.sleep(10)
            save_checkpoint(model, epoch, optimizer)

    if epoch is not None:
        time.sleep(10)
        save_checkpoint(model, epoch, optimizer)
    else:
        print('-------Nothing to train: start_epoch >= epochs')
    writer.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #4
0
def main():
    """Create the model and start the evaluation process.

    Parses command-line arguments with ``get_arguments()`` and builds a
    ``CorrPM_Model``. The model is wrapped in ``DataParallelModel`` only when
    more than one GPU is listed. Weights are restored from
    ``args.restore_from``, and the model is evaluated on the LIP split named by
    ``args.dataset`` via ``valid``. The mean IoU from ``compute_mean_ioU`` is
    then printed.

    Checkpoint keys are normalised with respect to DataParallel's
    ``'module.'`` prefix, so checkpoints saved from wrapped or bare models can
    be evaluated with either GPU configuration.

    Raises:
        ValueError: if ``args.data_name`` is not ``'lip'`` (the only
            implemented validation pipeline).
    """
    # Worker processes must use 'spawn': CUDA cannot be re-initialised in a
    # forked child process.
    torch.multiprocessing.set_start_method("spawn", force=True)
    args = get_arguments()

    # Fail before any expensive setup if there is no loader for this dataset.
    if args.data_name != 'lip':
        raise ValueError(
            f"Unsupported data_name {args.data_name!r}; only 'lip' is implemented.")

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    gpus = [int(i) for i in args.gpu.split(',')]
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    deeplab = CorrPM_Model(args.num_classes, args.num_points)
    model = DataParallelModel(deeplab) if len(gpus) > 1 else deeplab

    os.makedirs(args.save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    lip_dataset = LIPDataSet(args.data_dir,
                             VAL_POSE_ANNO_FILE,
                             args.dataset,
                             crop_size=input_size,
                             transform=transform)
    num_samples = len(lip_dataset)
    valloader = data.DataLoader(lip_dataset,
                                batch_size=args.batch_size * len(gpus),
                                shuffle=False,
                                num_workers=4,
                                pin_memory=True)

    # A DataParallel wrapper prefixes every state-dict key with 'module.'.
    # Strip that prefix from checkpoint keys, then re-add it only if the
    # current model is wrapped, so keys are matched by name, not position.
    prefix = 'module.'
    state_dict = model.state_dict().copy()
    model_is_prefixed = all(key.startswith(prefix) for key in state_dict)
    checkpoint = torch.load(args.restore_from)
    remapped = {}
    for key, tensor in checkpoint.items():
        bare_key = key[len(prefix):] if key.startswith(prefix) else key
        remapped[prefix + bare_key if model_is_prefixed else bare_key] = tensor

    # Report model entries the checkpoint does not provide; they keep their
    # freshly initialised values.
    for key in state_dict:
        if key not in remapped:
            print(key)
    state_dict.update(remapped)

    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()

    parsing_preds, scales, centers = valid(model, valloader, input_size,
                                           num_samples, len(gpus))

    mIoU = compute_mean_ioU(parsing_preds, scales, centers, args.num_classes,
                            args.data_dir, input_size, args.dataset)
    print(mIoU)

    end = datetime.datetime.now()
    # `start` is a module-level datetime, so the difference is a timedelta
    # (H:MM:SS), not a number of seconds.
    print('Elapsed time:', end - start)
    print(end)
Example #5
0
def main():
    """Create the model and start the training.

    All configuration comes from the module-level ``args`` namespace. A
    ``Res_Deeplab`` network is trained on LIP with ``CriterionAll`` applied to
    the parsing labels plus the ``r1..r4`` / ``l0..l5`` auxiliary targets. It is
    validated with ``valid`` / ``compute_mean_ioU`` after every epoch.

    Outputs go under a per-run directory ``args.snapshot_dir/args.date``
    (TensorBoard events and ``LIP_epoch_<N>.pth`` snapshots), plus a text log
    in ``args.log_dir``. Each snapshot stores ``'net'`` (the unwrapped network
    state), ``'epoch'`` (the last completed epoch) and ``'optimizer'``.

    With ``args.resume`` the network (and the optimizer, when the snapshot has
    one) is restored from ``args.restore_from``. Training then continues with
    the epoch after the stored one. Without it, ``args.restore_from`` is
    treated as pretrained weights and every entry except ``fc.*`` is copied in.
    """
    timestamp = args.date
    run_dir = os.path.join(args.snapshot_dir, timestamp)
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    writer = SummaryWriter(run_dir)
    # 'None' means "do not restrict CUDA_VISIBLE_DEVICES". It must be checked
    # before parsing the list, otherwise int('None') raises ValueError.
    if args.gpu == 'None':
        gpus = list(range(max(torch.cuda.device_count(), 1)))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)
    model = DataParallelModel(deeplab)

    checkpoint = None
    if args.resume:
        # Snapshots are written from model.module.state_dict(), i.e. without
        # DataParallel's 'module.' key prefix, so load into the bare network
        # (model.module is deeplab) rather than into the wrapper.
        checkpoint = torch.load(args.restore_from)
        deeplab.load_state_dict(checkpoint['net'])
        # The stored epoch was already completed; continue with the next one.
        args.start_epoch = checkpoint['epoch'] + 1
    else:
        saved_state_dict = torch.load(args.restore_from)
        new_params = deeplab.state_dict().copy()
        for name, tensor in saved_state_dict.items():
            # Skip the pretrained classifier ('fc.*'); this network has its
            # own head sized for args.num_classes.
            if name.split('.')[0] != 'fc':
                new_params[name] = tensor
        deeplab.load_state_dict(new_params)

    model.cuda()

    criterion = DataParallelCriterion(CriterionAll())
    criterion.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainloader = data.DataLoader(LIPDataSet(args.data_dir,
                                             args.dataset,
                                             crop_size=input_size,
                                             transform=transform),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=2,
                                  pin_memory=True)
    lip_dataset = LIPDataSet(args.data_dir,
                             'val',
                             crop_size=input_size,
                             transform=transform)
    num_samples = len(lip_dataset)

    valloader = data.DataLoader(lip_dataset,
                                batch_size=args.batch_size * len(gpus),
                                shuffle=False,
                                pin_memory=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # Older snapshots contain only 'net' and 'epoch'; restore the momentum
    # buffers when they are available.
    if checkpoint is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    del checkpoint  # release the loaded tensors before training
    optimizer.zero_grad()

    total_iters = args.epochs * len(trainloader)
    log = Logger(os.path.join(args.log_dir, '{}_train.log'.format(timestamp)),
                 level='debug')
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration index, used by the LR schedule and the logs.
            i_iter += len(trainloader) * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, *aux_targets, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            r1, r2, r3, r4, l0, l1, l2, l3, l4, l5 = (
                t.long().cuda(non_blocking=True) for t in aux_targets)

            preds = model(images)

            loss = criterion(
                preds, [[labels], [r1, r2, r3, r4], [l0, l1, l2, l3, l4, l5]])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # A single device->host sync, shared by TensorBoard and the log.
            loss_value = loss.item()
            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss_value, i_iter)

            if i_iter % 10 == 0:
                log.logger.info(
                    'epoch = {} iter = {} of {} completed, lr = {}, loss = {}'.
                    format(epoch, i_iter, total_iters, lr, loss_value))
        parsing_preds, scales, centers = valid(model, valloader, input_size,
                                               num_samples, len(gpus))
        mIoU = compute_mean_ioU(parsing_preds, scales, centers,
                                args.num_classes, args.data_dir, input_size)

        log.logger.info('epoch = {}'.format(epoch))
        log.logger.info(str(mIoU))
        writer.add_scalars('mIoU', mIoU, epoch)

        # Save the unwrapped network so snapshots load without DataParallel.
        state = {
            "net": model.module.state_dict(),
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
        }
        torch.save(state,
                   os.path.join(run_dir, 'LIP_epoch_' + str(epoch) + '.pth'))

    writer.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
Example #6
0
def main():
    """Create the model and start the training.

    All configuration comes from the module-level ``args`` namespace. A
    ``get_resnet101_asp_oc_dsn`` network is trained on LIP (with
    ``args.list_path``) using the sum of two losses: ``LovaszSoftmaxDSN`` and
    ``CriterionDSN``.

    Weights come from ``args.restore_from``:

    * ``args.start_epoch > 0``: the checkpoint's ``'state_dict'`` (for the
      DataParallel wrapper) and ``'optimizer'`` are loaded to resume;
    * otherwise: checkpoint entries whose names exist in the network are
      copied in, and the rest are reported.

    A checkpoint is saved every ``args.save_step`` epochs (counting from epoch
    0) and after the last trained epoch.
    """
    print(args)
    os.makedirs(args.snapshot_dir, exist_ok=True)

    writer = SummaryWriter(args.snapshot_dir)
    # 'None' means "do not restrict CUDA_VISIBLE_DEVICES". It must be checked
    # before parsing the list, otherwise int('None') raises ValueError.
    if args.gpu == 'None':
        gpus = list(range(max(torch.cuda.device_count(), 1)))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    h, w = map(int, args.input_size.split(','))
    input_size = [h, w]

    # cudnn related setting
    cudnn.enabled = True
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    deeplab = get_resnet101_asp_oc_dsn(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)

    if args.start_epoch > 0:
        model = DataParallelModel(deeplab)
        model.load_state_dict(saved_state_dict['state_dict'])
    else:
        new_params = deeplab.state_dict().copy()
        for state_name, tensor in saved_state_dict.items():
            if state_name in new_params:
                new_params[state_name] = tensor
            else:
                print('NOT LOAD', state_name)
        deeplab.load_state_dict(new_params)
        model = DataParallelModel(deeplab)
    print('-------Load Weight', args.restore_from)

    model.cuda()

    criterion = LovaszSoftmaxDSN(input_size)
    print('LOSS1: LovaszSoftmaxDSN')
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    criterion_softmax = CriterionDSN()
    print('LOSS2: CriterionDSN')
    criterion_softmax = DataParallelCriterion(criterion_softmax)
    criterion_softmax.cuda()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainloader = data.DataLoader(LIPDataSet(args.data_dir,
                                             args.dataset,
                                             crop_size=input_size,
                                             transform=transform,
                                             list_path=args.list_path),
                                  batch_size=args.batch_size * len(gpus),
                                  shuffle=True,
                                  num_workers=4,
                                  pin_memory=True)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    if args.start_epoch > 0:
        optimizer.load_state_dict(saved_state_dict['optimizer'])
        print('========Load Optimizer', args.restore_from)

    iters_per_epoch = len(trainloader)
    total_iters = args.epochs * iters_per_epoch
    epoch = None  # stays None when there is nothing left to train
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        for i_iter, batch in enumerate(trainloader):
            # Global iteration index, used by the LR schedule and the logs.
            i_iter += iters_per_epoch * epoch
            lr = adjust_learning_rate(optimizer, i_iter, total_iters)

            images, labels, _ = batch
            labels = labels.long().cuda(non_blocking=True)
            preds = model(images)

            loss1 = criterion(preds, labels)
            loss2 = criterion_softmax(preds, labels)
            loss = loss1 + loss2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sync each scalar to the host once per iteration.
            loss_value = loss.item()
            if i_iter % 100 == 0:
                writer.add_scalar('learning_rate', lr, i_iter)
                writer.add_scalar('loss', loss_value, i_iter)

            print(
                'epoch = {}, iter = {} of {} completed,lr={:.4f}, loss = {:.4f}, IoU_loss = {:.4f}, BCE_loss = {:.4f}'
                .format(epoch, i_iter, total_iters, lr,
                        loss_value,
                        loss1.item(),
                        loss2.item()))

        # Periodic snapshot. The last epoch is always saved after the loop
        # (inside range() `epoch` never reaches args.epochs).
        if epoch % args.save_step == 0:
            time.sleep(10)
            save_checkpoint(model, epoch, optimizer)

    if epoch is not None:
        time.sleep(10)
        save_checkpoint(model, epoch, optimizer)
    else:
        print('-------Nothing to train: start_epoch >= epochs')
    writer.close()
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train the parsing network with periodic self-correction (SCHP) cycles.

    Reads the run configuration from ``get_arguments()``, builds the training
    model plus a second "SCHP" model of the same architecture, and trains the
    first with SGD under ``SGDRScheduler``.

    Once training reaches ``args.schp_start`` epochs, the following happens
    every ``args.cycle_epochs`` epochs:

    * The SCHP model is updated with ``schp.moving_average(schp_model, model,
      1.0 / (cycle_n + 1))``.
    * Its batch-norm statistics are re-estimated on the training loader.
    * ``cycle_n`` is incremented.

    While ``cycle_n >= 1``, the SCHP model's parsing and edge predictions are
    passed to the criterion as soft targets.

    Side effects:

    * Creates ``args.log_dir`` and writes ``args.json`` into it.
    * Sets ``CUDA_VISIBLE_DEVICES`` unless ``args.gpu`` is the string
      ``'None'``.
    * Saves model checkpoints every ``args.eval_epochs`` epochs and an SCHP
      checkpoint after each cycle, via ``schp.save_checkpoint`` and
      ``schp.save_schp_checkpoint``.

    Raises:
        ValueError: if the model's ``input_space`` is neither ``'BGR'``
            nor ``'RGB'``.
    """
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    # args.gpu is either a comma-separated list of device ids or the literal
    # string 'None'. It must be checked before int() parsing, which would
    # raise ValueError on 'None'. In the 'None' case the device count is
    # taken from all visible CUDA devices. This assumes DataParallelModel
    # also spans all visible devices; TODO confirm.
    if args.gpu == 'None':
        gpus = list(range(torch.cuda.device_count()))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    AugmentCE2P = networks.init_model(args.arch,
                                      num_classes=args.num_classes,
                                      pretrained=args.imagenet_pretrain)
    model = DataParallelModel(AugmentCE2P)
    model.cuda()

    # Normalisation constants and channel order are taken from the network
    # definition so that the data pipeline matches the backbone.
    IMAGE_MEAN = AugmentCE2P.mean
    IMAGE_STD = AugmentCE2P.std
    INPUT_SPACE = AugmentCE2P.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space:{}'.format(INPUT_SPACE))

    # Resume: only the weights and the epoch counter are restored. The
    # checkpoints written below contain just 'epoch' and 'state_dict', so
    # optimizer momentum starts fresh.
    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    # Second copy of the network that accumulates the self-correction
    # (model-aggregation) weights.
    SCHP_AugmentCE2P = networks.init_model(args.arch,
                                           num_classes=args.num_classes,
                                           pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(SCHP_AugmentCE2P)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model_state_dict = schp_checkpoint['state_dict']
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_model_state_dict)

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s,
                             lambda_2=args.lambda_e,
                             lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])

    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN, std=IMAGE_STD),
        ])

    else:
        # Without this branch `transform` would be unbound and the failure
        # would surface later as an opaque UnboundLocalError.
        raise ValueError(
            'Unsupported input space: {!r}'.format(INPUT_SPACE))

    train_dataset = LIPDataSet(args.data_dir,
                               args.split_name,
                               crop_size=input_size,
                               transform=transform)
    # args.batch_size is treated as a per-GPU batch size.
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size * len(gpus),
                                   num_workers=16,
                                   shuffle=True,
                                   pin_memory=True,
                                   drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer,
                                 total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100,
                                 warmup_epoch=10,
                                 start_cyclical=args.schp_start,
                                 cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            # Convert the per-epoch batch index into a global iteration count.
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement: once at
            # least one aggregation cycle has run, the SCHP model's final
            # parsing/edge outputs (gathered from each replica) serve as
            # soft targets.
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds = schp_model(images)
                    soft_parsing = []
                    soft_edge = []
                    for soft_pred in soft_preds:
                        soft_parsing.append(soft_pred[0][-1])
                        soft_edge.append(soft_pred[1][-1])
                    soft_preds = torch.cat(soft_parsing, dim=0)
                    soft_edges = torch.cat(soft_edge, dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges],
                             cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                # loss is a scalar (backward() above takes no gradient
                # argument), so .item() is the idiomatic read-out.
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(
                    i_iter, total_iters, lr,
                    loss.item()))
        if (epoch + 1) % (args.eval_epochs) == 0:
            schp.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                },
                False,
                args.log_dir,
                filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation: fold the current
        # weights into the SCHP model with weight 1 / (cycle_n + 1), then
        # re-estimate its batch-norm statistics on the training data.
        if (epoch + 1) >= args.schp_start and (
                epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint(
                {
                    'state_dict': schp_model.state_dict(),
                    'cycle_n': cycle_n,
                },
                False,
                args.log_dir,
                filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        # Reports the average wall-clock time per epoch since this run began.
        print('epoch = {} of {} completed using {} s'.format(
            epoch, args.epochs, (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))