예제 #1
0
def main(args):
    """Train a 3DGNN semantic-segmentation model on NYUDv2.

    Builds train/validation dataloaders (single-headed or two-headed,
    selected by ``args.is_2_headed``), trains for ``args.num_epochs``
    epochs with Adam and a configurable LR schedule, and writes logs and
    a per-epoch checkpoint under ``./experiment/<timestamp>/``.

    Args:
        args: parsed CLI namespace; the fields read here are ``gpu``,
            ``is_2_headed``, ``batchsize``, ``num_epochs`` and
            ``pretrain`` (path to a state dict, or falsy to train from
            scratch).
    """
    # Restrict CUDA to the requested device(s) before any CUDA context exists.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    logger = logging.getLogger('3dgnn')
    # One log directory per hour, e.g. ./experiment/2024-01-01-12/ .
    log_path = './experiment/' + str(
        datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ',
                                                                 '/') + '/'
    print('log path is:', log_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(log_path + 'save/')
    hdlr = logging.FileHandler(log_path + 'log.txt')
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    logger.info("Loading data...")
    print("Loading data...")
    '''idx_to_label = {0: '<UNK>', 1: 'beam', 2: 'board', 3: 'bookcase', 4: 'ceiling', 5: 'chair', 6: 'clutter',
                    7: 'column',
                    8: 'door', 9: 'floor', 10: 'sofa', 11: 'table', 12: 'wall', 13: 'window'}*'''

    # Training set: random crops with flip augmentation from config.
    if args.is_2_headed:
        dataset_tr = nyud2headed.Dataset(flip_prob=config.flip_prob,
                                         crop_type='Random',
                                         crop_size=config.crop_size)
    else:
        dataset_tr = nyudv2.Dataset(flip_prob=config.flip_prob,
                                    crop_type='Random',
                                    crop_size=config.crop_size)
    idx_to_label = dataset_tr.label_names
    if args.is_2_headed:
        idx_to_label2 = dataset_tr.label2_names

    dataloader_tr = DataLoader(dataset_tr,
                               batch_size=args.batchsize,
                               shuffle=True,
                               num_workers=config.workers_tr,
                               drop_last=False,
                               pin_memory=True)

    # Validation set: deterministic center crops, no augmentation.
    if args.is_2_headed:
        dataset_va = nyud2headed.Dataset(flip_prob=0.0,
                                         crop_type='Center',
                                         crop_size=config.crop_size)
    else:
        dataset_va = nyudv2.Dataset(flip_prob=0.0,
                                    crop_type='Center',
                                    crop_size=config.crop_size)
    dataloader_va = DataLoader(dataset_va,
                               batch_size=args.batchsize,
                               shuffle=False,
                               num_workers=config.workers_va,
                               drop_last=False,
                               pin_memory=True)
    # Keep OpenCV's internal threading from oversubscribing the CPU
    # alongside the DataLoader workers.
    cv2.setNumThreads(config.workers_tr)

    logger.info("Preparing model...")
    print("Preparing model...")

    # Class 0 gets zero weight so the void/<UNK> label is ignored by the loss.
    class_weights = [0.0] + [1.0 for i in range(1, len(idx_to_label))]
    nclasses = len(class_weights)
    if args.is_2_headed:
        nclasses1 = nclasses
        class2_weights = [0.0] + [1.0 for i in range(1, len(idx_to_label2))]
        nclasses2 = len(class2_weights)
        model = Model2Headed(nclasses1, nclasses2, config.mlp_num_layers,
                             config.use_gpu)
        # NOTE(review): `reduce=` is deprecated in torch.nn.NLLLoss in
        # favor of `reduction=`; kept as-is to preserve behavior on the
        # PyTorch version this was written for.
        loss2 = nn.NLLLoss(reduce=not config.use_bootstrap_loss,
                           weight=torch.FloatTensor(class2_weights))
    else:
        model = Model(nclasses, config.mlp_num_layers, config.use_gpu)
    loss = nn.NLLLoss(reduce=not config.use_bootstrap_loss,
                      weight=torch.FloatTensor(class_weights))

    softmax = nn.Softmax(dim=1)
    log_softmax = nn.LogSoftmax(dim=1)

    if config.use_gpu:
        model = model.cuda()
        loss = loss.cuda()
        if args.is_2_headed:
            loss2 = loss2.cuda()
        softmax = softmax.cuda()
        log_softmax = log_softmax.cuda()

    # Two parameter groups: the GNN gets its own (typically higher)
    # learning rate; the decoder uses the base rate.
    optimizer = torch.optim.Adam([{
        'params': model.decoder.parameters()
    }, {
        'params': model.gnn.parameters(),
        'lr': config.gnn_initial_lr
    }],
                                 lr=config.base_initial_lr,
                                 betas=config.betas,
                                 eps=config.eps,
                                 weight_decay=config.weight_decay)

    # Polynomial ("exp") decay over epochs, or plateau-based decay on the
    # eval loss (the plateau step call is currently commented out below).
    if config.lr_schedule_type == 'exp':
        lambda1 = lambda epoch: pow(
            (1 - ((epoch - 1) / args.num_epochs)), config.lr_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda1)
    elif config.lr_schedule_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=config.lr_decay, patience=config.lr_patience)
    else:
        print('bad scheduler')
        exit(1)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    logger.info("Number of trainable parameters: %d", params)

    def get_current_learning_rates():
        """Return the current learning rate of each optimizer param group."""
        learning_rates = []
        for param_group in optimizer.param_groups:
            learning_rates.append(param_group['lr'])
        return learning_rates

    def eval_set(dataloader):
        """Evaluate the model on `dataloader`.

        Returns:
            (mean loss over batches, per-class IoU array, confusion matrix)
            with class 0 (void) zeroed out of the confusion matrix.

        NOTE(review): this assumes a single-headed model (one output
        tensor); for the two-headed model the forward call below would
        return a tuple (see the commented-out call), so this function is
        not usable as-is in that mode — confirm before re-enabling the
        evaluation block in the training loop.
        """
        model.eval()

        with torch.no_grad():
            loss_sum = 0.0
            # Flat confusion matrix of size C*C, reshaped at the end.
            if config.use_gpu:
                confusion_matrix = torch.cuda.FloatTensor(
                    np.zeros(len(idx_to_label)**2))
            else:
                confusion_matrix = torch.FloatTensor(
                    np.zeros(len(idx_to_label)**2))

            start_time = time.time()

            for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader),
                                                 total=len(dataloader),
                                                 smoothing=0.9):
                # Batch layout: (rgb_hha, label, xy) — see example loader.
                x = rgbd_label_xy[0]
                xy = rgbd_label_xy[2]
                target = rgbd_label_xy[1].long()
                x = x.float()
                xy = xy.float()

                # NHWC -> NCHW for the network.
                input = x.permute(0, 3, 1, 2).contiguous()
                xy = xy.permute(0, 3, 1, 2).contiguous()
                if config.use_gpu:
                    input = input.cuda()
                    xy = xy.cuda()
                    target = target.cuda()

                output = model(input,
                               gnn_iterations=config.gnn_iterations,
                               k=config.gnn_k,
                               xy=xy,
                               use_gnn=config.use_gnn)
                # if args.is_2_headed:
                #     output1, output2 = model(input, gnn_iterations=config.gnn_iterations, k=config.gnn_k, xy=xy,
                #                              use_gnn=config.use_gnn)

                if config.use_bootstrap_loss:
                    # Bootstrapped loss: average only the top-k hardest
                    # pixels per image.
                    loss_per_pixel = loss.forward(log_softmax(output.float()),
                                                  target)
                    topk, indices = torch.topk(
                        loss_per_pixel.view(output.size()[0], -1),
                        int((config.crop_size**2) * config.bootstrap_rate))
                    loss_ = torch.mean(topk)
                else:
                    loss_ = loss.forward(log_softmax(output.float()), target)
                loss_sum += loss_

                pred = output.permute(0, 2, 3, 1).contiguous()
                pred = pred.view(-1, nclasses)
                pred = softmax(pred)
                pred_max_val, pred_arg_max = pred.max(1)

                # Encode (target, prediction) pairs as a single index into
                # the flat C*C confusion matrix.
                pairs = target.view(-1) * len(
                    idx_to_label) + pred_arg_max.view(-1)
                for i in range(len(idx_to_label)**2):
                    cumu = pairs.eq(i).float().sum()
                    confusion_matrix[i] += cumu.item()

            sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() -
                                                               start_time))
            loss_sum /= len(dataloader)

            confusion_matrix = confusion_matrix.cpu().numpy().reshape(
                (len(idx_to_label), len(idx_to_label)))
            class_iou = np.zeros(len(idx_to_label))
            # Ignore the void class (row and column 0).
            confusion_matrix[0, :] = np.zeros(len(idx_to_label))
            confusion_matrix[:, 0] = np.zeros(len(idx_to_label))
            for i in range(1, len(idx_to_label)):
                # IoU_i = TP / (TP + FP + FN), guarded against empty classes.
                tot = np.sum(confusion_matrix[i, :]) + np.sum(
                    confusion_matrix[:, i]) - confusion_matrix[i, i]
                if tot == 0:
                    class_iou[i] = 0
                else:
                    class_iou[i] = confusion_matrix[i, i] / tot

        return loss_sum.item(), class_iou, confusion_matrix

    '''Training parameter'''
    model_to_load = args.pretrain
    logger.info("num_epochs: %d", args.num_epochs)
    print("Number of epochs: %d" % args.num_epochs)
    # Log the running batch-loss average every `interval_to_show` batches.
    interval_to_show = 100

    train_losses = []
    eval_losses = []

    if model_to_load:
        logger.info("Loading old model...")
        print("Loading old model...")
        model.load_state_dict(torch.load(model_to_load))
    else:
        # print("here")
        # exit(0)
        logger.info("Starting training from scratch...")
        print("Starting training from scratch...")
    '''Training'''
    for epoch in range(1, args.num_epochs + 1):
        print("epoch", epoch)
        batch_loss_avg = 0
        # NOTE(review): calling scheduler.step(epoch) before the epoch's
        # optimizer.step() calls is the pre-1.1 PyTorch ordering; newer
        # versions warn about it. Kept to preserve the original LR sequence.
        if config.lr_schedule_type == 'exp':
            scheduler.step(epoch)
        for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader_tr),
                                             total=len(dataloader_tr),
                                             smoothing=0.9):
            x = rgbd_label_xy[0]
            target = rgbd_label_xy[1].long()
            if args.is_2_headed:
                # Second head's labels ride in slot 3 of the batch tuple.
                target2 = rgbd_label_xy[3].long()
            xy = rgbd_label_xy[2]
            x = x.float()
            xy = xy.float()

            # NHWC -> NCHW for the network.
            input = x.permute(0, 3, 1, 2).contiguous()
            input = input.type(torch.FloatTensor)

            if config.use_gpu:
                input = input.cuda()
                xy = xy.cuda()
                target = target.cuda()
                if args.is_2_headed:
                    target2 = target2.cuda()

            xy = xy.permute(0, 3, 1, 2).contiguous()

            optimizer.zero_grad()
            model.train()

            if args.is_2_headed:
                output1, output2 = model(input,
                                         gnn_iterations=config.gnn_iterations,
                                         k=config.gnn_k,
                                         xy=xy,
                                         use_gnn=config.use_gnn)
            else:
                output = model(input,
                               gnn_iterations=config.gnn_iterations,
                               k=config.gnn_k,
                               xy=xy,
                               use_gnn=config.use_gnn)

            # NOTE(review): with use_bootstrap_loss=True AND is_2_headed=True
            # this branch references `output`, which is unbound in the
            # two-headed path above — it would raise NameError. Presumably
            # bootstrap loss is only used single-headed; confirm.
            if config.use_bootstrap_loss:
                loss_per_pixel = loss.forward(log_softmax(output.float()),
                                              target)
                topk, indices = torch.topk(
                    loss_per_pixel.view(output.size()[0], -1),
                    int((config.crop_size**2) * config.bootstrap_rate))
                loss_ = torch.mean(topk)
            else:
                if args.is_2_headed:
                    # Total loss is the sum of both heads' NLL losses.
                    loss_ = loss.forward(log_softmax(
                        output1.float()), target) + loss2.forward(
                            log_softmax(output2.float()), target2)
                else:
                    loss_ = loss.forward(log_softmax(output.float()), target)

            loss_.backward()
            optimizer.step()

            batch_loss_avg += loss_.item()

            if batch_idx % interval_to_show == 0 and batch_idx > 0:
                batch_loss_avg /= interval_to_show
                train_losses.append(batch_loss_avg)
                logger.info("E%dB%d Batch loss average: %s", epoch, batch_idx,
                            batch_loss_avg)
                print('\rEpoch:{}, Batch:{}, loss average:{}'.format(
                    epoch, batch_idx, batch_loss_avg))
                batch_loss_avg = 0

        batch_idx = len(dataloader_tr)
        logger.info("E%dB%d Saving model...", epoch, batch_idx)

        # Checkpoint after every epoch.
        torch.save(model.state_dict(),
                   log_path + '/save/' + 'checkpoint_' + str(epoch) + '.pth')
        '''Evaluation'''
        # eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
        # eval_losses.append(eval_loss)
        #
        # if config.lr_schedule_type == 'plateau':
        #     scheduler.step(eval_loss)
        print('Learning ...')
        logger.info("E%dB%d Def learning rate: %s", epoch, batch_idx,
                    get_current_learning_rates()[0])
        print('Epoch{} Def learning rate: {}'.format(
            epoch,
            get_current_learning_rates()[0]))
        logger.info("E%dB%d GNN learning rate: %s", epoch, batch_idx,
                    get_current_learning_rates()[1])
        print('Epoch{} GNN learning rate: {}'.format(
            epoch,
            get_current_learning_rates()[1]))
        # logger.info("E%dB%d Eval loss: %s", epoch, batch_idx, eval_loss)
        # print('Epoch{} Eval loss: {}'.format(epoch, eval_loss))
        # logger.info("E%dB%d Class IoU:", epoch, batch_idx)
        # print('Epoch{} Class IoU:'.format(epoch))
        # for cl in range(len(idx_to_label)):
        #     logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
        #     print('{}:{}'.format(idx_to_label[cl], class_iou[cl]))
        # logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
        # print("Mean IoU: %.2f" % np.mean(class_iou[1:]))
        # logger.info("E%dB%d Confusion matrix:", epoch, batch_idx)
        # logger.info(confusion_matrix)

    logger.info("Finished training!")
    logger.info("Saving model...")
    print('Saving final model...')
    torch.save(model.state_dict(), log_path + '/save/3dgnn_finish.pth')
예제 #2
0
def main():
    """Train the 3DGNN-ENet model on the S3DIS label set (14 classes).

    Self-contained variant with all hyper-parameters inlined instead of a
    `config` module. Trains for 50 epochs, evaluates on the validation set
    after each epoch, and writes logs plus per-epoch checkpoints under
    ``<cwd>/artifacts/<timestamp>/``.
    """
    model_name = '3dgnn_enet'
    current_path = os.getcwd()
    logger = logging.getLogger(model_name)
    # One log directory per hour, e.g. <cwd>/artifacts/2024-01-01-12/ .
    log_path = current_path + '/artifacts/'+ str(datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ', '/') + '/'
    print('log path is:',log_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(log_path + 'save/')
    hdlr = logging.FileHandler(log_path + model_name + '.log')
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.INFO)
    logger.info("Loading data...")
    print("Loading data...")

    # S3DIS label set; index 0 is the void/<UNK> class.
    label_to_idx = {'<UNK>': 0, 'beam': 1, 'board': 2, 'bookcase': 3, 'ceiling': 4, 'chair': 5, 'clutter': 6,
                    'column': 7,
                    'door': 8, 'floor': 9, 'sofa': 10, 'table': 11, 'wall': 12, 'window': 13}
    idx_to_label = {0: '<UNK>', 1: 'beam', 2: 'board', 3: 'bookcase', 4: 'ceiling', 5: 'chair', 6: 'clutter',
                    7: 'column',
                    8: 'door', 9: 'floor', 10: 'sofa', 11: 'table', 12: 'wall', 13: 'window'}

    '''Data Loader parameter'''
    # Batch size
    batch_size_tr = 4
    batch_size_va = 4
    # Multiple threads loading data
    workers_tr = 4
    workers_va = 4
    # Data augmentation
    flip_prob = 0.5
    # crop_size=0 presumably means "no crop / full image" — TODO confirm
    # against nyudv2.Dataset; note the bootstrap-loss topk below would be
    # int(0 * rate) = 0 if use_bootstrap_loss were enabled.
    crop_size = 0

    dataset_tr = nyudv2.Dataset(flip_prob=flip_prob,crop_type='Random',crop_size=crop_size)
    dataloader_tr = DataLoader(dataset_tr, batch_size=batch_size_tr, shuffle=True,
                               num_workers=workers_tr, drop_last=False, pin_memory=True)

    # Validation set: deterministic center crops, no flips.
    dataset_va = nyudv2.Dataset(flip_prob=0.0,crop_type='Center',crop_size=crop_size)
    dataloader_va = DataLoader(dataset_va, batch_size=batch_size_va, shuffle=False,
                               num_workers=workers_va, drop_last=False, pin_memory=True)
    # Keep OpenCV's threading from oversubscribing alongside DataLoader workers.
    cv2.setNumThreads(workers_tr)

    # Zero weight on class 0 so the void class is ignored by the loss.
    class_weights = [0.0]+[1.0 for i in range(13)]
    nclasses = len(class_weights)
    num_epochs = 50

    '''GNN parameter'''
    use_gnn = True
    gnn_iterations = 3
    gnn_k = 64
    mlp_num_layers = 1

    '''Model parameter'''
    use_bootstrap_loss = False
    bootstrap_rate = 0.25
    use_gpu = True

    logger.info("Preparing model...")
    print("Preparing model...")
    model = Model(nclasses, mlp_num_layers,use_gpu)
    # NOTE(review): `reduce=` is deprecated in torch.nn.NLLLoss in favor
    # of `reduction=`; kept as-is for the PyTorch version this targets.
    loss = nn.NLLLoss(reduce=not use_bootstrap_loss, weight=torch.FloatTensor(class_weights))
    softmax = nn.Softmax(dim=1)
    log_softmax = nn.LogSoftmax(dim=1)

    if use_gpu:
        model = model.cuda()
        loss = loss.cuda()
        softmax = softmax.cuda()
        log_softmax = log_softmax.cuda()

    '''Optimizer parameter'''
    base_initial_lr = 5e-4
    gnn_initial_lr = 1e-3
    betas = [0.9, 0.999]
    eps = 1e-08
    weight_decay = 1e-4
    lr_schedule_type = 'exp'
    lr_decay = 0.9
    lr_patience = 10

    # Two param groups: GNN gets its own (higher) LR, decoder uses the base LR.
    optimizer = torch.optim.Adam([{'params': model.decoder.parameters()},
                                  {'params': model.gnn.parameters(), 'lr': gnn_initial_lr}],
                                 lr=base_initial_lr, betas=betas, eps=eps, weight_decay=weight_decay)

    if lr_schedule_type == 'exp':
        # Polynomial decay over epochs: (1 - (epoch-1)/num_epochs) ** lr_decay.
        lambda1 = lambda epoch: pow((1 - ((epoch - 1) / num_epochs)), lr_decay)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
    elif lr_schedule_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=lr_decay, patience=lr_patience)
    else:
        print('bad scheduler')
        exit(1)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    logger.info("Number of trainable parameters: %d", params)

    def get_current_learning_rates():
        """Return the current learning rate of each optimizer param group."""
        learning_rates = []
        for param_group in optimizer.param_groups:
            learning_rates.append(param_group['lr'])
        return learning_rates

    def eval_set(dataloader):
        """Evaluate the model on `dataloader`.

        Returns:
            (mean loss over batches, per-class IoU array of length 14,
            14x14 confusion matrix) with the void class zeroed out.
        """
        model.eval()

        with torch.no_grad():
            loss_sum = 0.0
            # NOTE(review): hard-coded CUDA tensor — this requires a GPU
            # even if use_gpu were False (use_gpu is hard-coded True above).
            confusion_matrix = torch.cuda.FloatTensor(np.zeros(14 ** 2))

            start_time = time.time()

            for batch_idx, rgbd_label_xy in enumerate(dataloader):

                sys.stdout.write('\rEvaluating test set... {}/{}'.format(batch_idx + 1, len(dataloader)))
                # Batch layout: (rgb_hha, label, xy).
                x = rgbd_label_xy[0]
                xy = rgbd_label_xy[2]
                target = rgbd_label_xy[1].long()
                x = x.float()
                xy = xy.float()

                # NHWC -> NCHW for the network.
                input = x.permute(0, 3, 1, 2).contiguous()
                xy = xy.permute(0, 3, 1, 2).contiguous()
                if use_gpu:
                    input = input.cuda()
                    xy = xy.cuda()
                    target = target.cuda()

                output = model(input, gnn_iterations=gnn_iterations, k=gnn_k, xy=xy, use_gnn=use_gnn)

                if use_bootstrap_loss:
                    # Bootstrapped loss: average only the hardest pixels.
                    loss_per_pixel = loss.forward(log_softmax(output.float()), target)
                    topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1),
                                               int((crop_size ** 2) * bootstrap_rate))
                    loss_ = torch.mean(topk)
                else:
                    loss_ = loss.forward(log_softmax(output.float()), target)
                loss_sum += loss_

                pred = output.permute(0, 2, 3, 1).contiguous()
                pred = pred.view(-1, nclasses)
                pred = softmax(pred)
                pred_max_val, pred_arg_max = pred.max(1)

                # Encode (target, prediction) pairs as one index into the
                # flat 14*14 confusion matrix.
                pairs = target.view(-1) * 14 + pred_arg_max.view(-1)
                for i in range(14 ** 2):
                    cumu = pairs.eq(i).float().sum()
                    confusion_matrix[i] += cumu.item()

            sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() - start_time))
            loss_sum /= len(dataloader)

            confusion_matrix = confusion_matrix.cpu().numpy().reshape((14, 14))
            class_iou = np.zeros(14)
            # we ignore void values
            confusion_matrix[0, :] = np.zeros(14)
            confusion_matrix[:, 0] = np.zeros(14)
            for i in range(1, 14):
                # NOTE(review): divides by zero (NaN) if class i never
                # appears in targets or predictions; the other training
                # script guards this case with tot == 0.
                class_iou[i] = confusion_matrix[i, i] / (
                        np.sum(confusion_matrix[i, :]) + np.sum(confusion_matrix[:, i]) - confusion_matrix[i, i])

        return loss_sum.item(), class_iou, confusion_matrix

    '''Training parameter'''
    model_to_load = None
    logger.info("num_epochs: %d", num_epochs)
    print("Number of epochs: %d"%num_epochs)
    # Log the running batch-loss average every `interval_to_show` batches.
    interval_to_show = 100

    train_losses = []
    eval_losses = []

    if model_to_load:
        logger.info("Loading old model...")
        print("Loading old model...")
        model.load_state_dict(torch.load(model_to_load))
    else:
        logger.info("Starting training from scratch...")
        print("Starting training from scratch...")

    '''Training'''
    for epoch in range(1, num_epochs + 1):
        batch_loss_avg = 0
        # NOTE(review): scheduler.step(epoch) before the optimizer steps is
        # the pre-1.1 PyTorch ordering; newer versions warn about it.
        if lr_schedule_type == 'exp':
            scheduler.step(epoch)
        for batch_idx, rgbd_label_xy in enumerate(dataloader_tr):

            sys.stdout.write('\rTraining data set... {}/{}'.format(batch_idx + 1, len(dataloader_tr)))

            x = rgbd_label_xy[0]
            target = rgbd_label_xy[1].long()
            xy = rgbd_label_xy[2]
            x = x.float()
            xy = xy.float()

            # NHWC -> NCHW for the network.
            input = x.permute(0, 3, 1, 2).contiguous()
            input = input.type(torch.FloatTensor)

            if use_gpu:
                input = input.cuda()
                xy = xy.cuda()
                target = target.cuda()

            xy = xy.permute(0, 3, 1, 2).contiguous()

            optimizer.zero_grad()
            model.train()

            output = model(input, gnn_iterations=gnn_iterations, k=gnn_k, xy=xy, use_gnn=use_gnn)

            if use_bootstrap_loss:
                # Bootstrapped loss: average only the hardest pixels.
                loss_per_pixel = loss.forward(log_softmax(output.float()), target)
                topk, indices = torch.topk(loss_per_pixel.view(output.size()[0], -1),
                                           int((crop_size ** 2) * bootstrap_rate))
                loss_ = torch.mean(topk)
            else:
                loss_ = loss.forward(log_softmax(output.float()), target)

            loss_.backward()
            optimizer.step()

            batch_loss_avg += loss_.item()

            if batch_idx % interval_to_show == 0 and batch_idx > 0:
                batch_loss_avg /= interval_to_show
                train_losses.append(batch_loss_avg)
                logger.info("E%dB%d Batch loss average: %s", epoch, batch_idx, batch_loss_avg)
                print('\rEpoch:{}, Batch:{}, loss average:{}'.format(epoch, batch_idx, batch_loss_avg))
                batch_loss_avg = 0

        batch_idx = len(dataloader_tr)
        logger.info("E%dB%d Saving model...", epoch, batch_idx)

        # Checkpoint after every epoch.
        torch.save(model.state_dict(),log_path +'/save/'+'checkpoint_'+str(epoch)+'.pth')

        '''Evaluation'''
        eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
        eval_losses.append(eval_loss)

        if lr_schedule_type == 'plateau':
            scheduler.step(eval_loss)
        print('Learning ...')
        logger.info("E%dB%d Def learning rate: %s", epoch, batch_idx, get_current_learning_rates()[0])
        print('Epoch{} Def learning rate: {}'.format(epoch, get_current_learning_rates()[0]))
        logger.info("E%dB%d GNN learning rate: %s", epoch, batch_idx, get_current_learning_rates()[1])
        print('Epoch{} GNN learning rate: {}'.format(epoch, get_current_learning_rates()[1]))
        logger.info("E%dB%d Eval loss: %s", epoch, batch_idx, eval_loss)
        print('Epoch{} Eval loss: {}'.format(epoch, eval_loss))
        logger.info("E%dB%d Class IoU:", epoch, batch_idx)
        print('Epoch{} Class IoU:'.format(epoch))
        for cl in range(14):
            logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
            print('{}:{}'.format(idx_to_label[cl], class_iou[cl]))
        logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
        print("Mean IoU: %.2f"%np.mean(class_iou[1:]))
        logger.info("E%dB%d Confusion matrix:", epoch, batch_idx)
        logger.info(confusion_matrix)


    logger.info("Finished training!")
    logger.info("Saving model...")
    print('Saving final model...')
    torch.save(model.state_dict(), log_path + '/save/3dgnn_enet_finish.pth')
    # Final evaluation after the last checkpoint is written.
    eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
    logger.info("Eval loss: %s", eval_loss)
    logger.info("Class IoU:")
    for cl in range(14):
        logger.info("%+10s: %-10s" % (idx_to_label[cl], class_iou[cl]))
    logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
예제 #3
0
def main(args):
    """Run a trained 3DGNN model over NYUDv2 and save class overlays.

    For every image in the dataset and every non-void class, writes an
    overlay image (green channel brightened on pixels predicted as that
    class) to ``./segmented/<timestamp>/`` named
    ``<image_id>_<class_name>``.

    Args:
        args: parsed CLI namespace; fields used are ``batchsize``
            (dataloader batch size) and ``path`` (checkpoint to load).
    """
    dataset = nyudv2.Dataset()
    idx_to_label = dataset.label_names
    dataloader = DataLoader(dataset,
                            batch_size=args.batchsize,
                            shuffle=False,
                            num_workers=config.workers_tr,
                            drop_last=False,
                            pin_memory=True)

    # Class 0 (void) gets zero weight; the list is only used to size the net.
    class_weights = [0.0] + [1.0 for i in range(1, len(idx_to_label))]
    nclasses = len(class_weights)
    model = Model(nclasses, config.mlp_num_layers, config.use_gpu)
    print("Loading model...")
    # map_location keeps the load CPU-safe regardless of the device the
    # checkpoint was saved from.
    model.load_state_dict(
        torch.load(args.path, map_location=lambda storage, loc: storage))
    print("Finished loading model...")
    # Fix: inference must run in eval mode (BatchNorm/Dropout behavior).
    model.eval()

    softmax = nn.Softmax(dim=1)
    segmented_path = './segmented/' + str(
        datetime.datetime.now().strftime('%Y-%m-%d-%H')).replace(' ',
                                                                 '/') + '/'
    if not os.path.exists(segmented_path):
        os.makedirs(segmented_path)

    # Fix: disable autograd — no gradients are needed at inference time.
    with torch.no_grad():
        for batch_idx, rgbd_label_xy in tqdm(enumerate(dataloader),
                                             total=len(dataloader),
                                             smoothing=0.9):
            # Batch layout: (rgb_hha, label, xy); labels are unused here.
            x = rgbd_label_xy[0].float()
            xy = rgbd_label_xy[2].float()

            # NHWC -> NCHW for the network.
            input = x.permute(0, 3, 1, 2).contiguous()
            xy = xy.permute(0, 3, 1, 2).contiguous()
            if config.use_gpu:
                input = input.cuda()
                xy = xy.cuda()

            output = model(input,
                           gnn_iterations=config.gnn_iterations,
                           k=config.gnn_k,
                           xy=xy,
                           use_gnn=config.use_gnn)
            pred = output.permute(0, 2, 3, 1).contiguous()
            pred = pred.view(-1, nclasses)
            pred = softmax(pred)
            pred_max_val, pred_arg_max = pred.max(1)

            # Fix: use -1 instead of a hard-coded batch size of 4 so the
            # (possibly smaller) final batch does not break the reshape
            # (drop_last=False). Fix: move to CPU before .numpy() so the
            # GPU path works. Assumes 640x480 images — TODO confirm with
            # nyudv2.Dataset.
            predictions = pred_arg_max.view(-1, 640,
                                            480).permute(0, 2, 1).cpu().numpy()
            # RGB channels of the network input as integer pixel values.
            # Fix: np.int was removed in NumPy >= 1.24; the builtin int is
            # the documented replacement. Hoisted out of the class loop —
            # it does not depend on the class index.
            origin = input.permute(0, 3, 2,
                                   1)[:, :, :, 0:3].cpu().numpy().astype(int)

            for i in range(1, len(idx_to_label)):
                # Fresh copy per class: the overlay mutates the green channel.
                res = origin.copy()
                mask = np.equal(i, predictions)
                res[:, :, :, 1] = res[:, :, :, 1] + 0.4 * 255 * mask
                # Fix: iterate the actual batch size, not args.batchsize,
                # so a short final batch does not index out of range.
                for j in range(res.shape[0]):
                    plt.imsave(
                        segmented_path +
                        str(5001 + batch_idx * args.batchsize + j) + '_' +
                        idx_to_label[i], res[j].astype(np.uint8))
예제 #4
0
def train_nn(dataset_path,
             hha_dir,
             save_models_dir,
             num_epochs=50,
             batch_size=4,
             from_last_check_point=False,
             check_point_prefix='checkpoint',
             start_epoch=0,
             pre_train_model='',
             notebook=False):
    """Train the 3DGNN segmentation model on the NYUDv2 dataset.

    Runs the full train/eval loop: builds train and validation loaders,
    constructs the model, optimizer and LR scheduler from ``config``,
    checkpoints after every epoch, and logs per-class IoU computed from a
    confusion matrix over ``config.nclasses`` classes (class 0 is excluded
    from IoU as the unknown/background class).

    Args:
        dataset_path: Root directory of the NYUDv2 dataset.
        hha_dir: Directory with the HHA depth encodings for the dataset.
        save_models_dir: Directory where checkpoints are written.
        num_epochs: Last epoch index; the loop runs ``start_epoch..num_epochs``.
        batch_size: Mini-batch size for both training and validation.
        from_last_check_point: If True, resume from the newest checkpoint in
            ``save_models_dir`` (overrides ``start_epoch``/``pre_train_model``).
        check_point_prefix: Filename prefix for checkpoint files.
        start_epoch: Epoch index to start (or resume) from.
        pre_train_model: Optional path to a state dict to initialize from.
        notebook: If True, use tqdm_notebook progress bars instead of tqdm.
    """
    progress = tqdm_notebook if notebook else tqdm
    logger.info('Loading data...')

    dataset_tr = nyudv2.Dataset(dataset_path,
                                hha_dir,
                                flip_prob=config.flip_prob,
                                crop_type='Random',
                                crop_size=config.crop_size)
    dataloader_tr = DataLoader(dataset_tr,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=config.workers_tr,
                               drop_last=False,
                               pin_memory=True)

    # Validation set: no flips, deterministic center crop.
    dataset_va = nyudv2.Dataset(dataset_path,
                                hha_dir,
                                flip_prob=0.0,
                                crop_type='Center',
                                crop_size=config.crop_size)
    dataloader_va = DataLoader(dataset_va,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=config.workers_va,
                               drop_last=False,
                               pin_memory=True)

    if from_last_check_point:
        # Resuming overrides both start_epoch and pre_train_model.
        start_epoch, pre_train_model = find_last_check_point(
            save_models_dir, check_point_prefix)

    cv2.setNumThreads(config.workers_tr)

    logger.info('Preparing model...')
    # Single source of truth for the class count (previously a hard-coded 14
    # was used in eval_set and the IoU reports below).
    nclasses = config.nclasses
    model = Model(nclasses, config.mlp_num_layers, config.use_gpu)
    # Bootstrapped (top-k) loss needs per-pixel values, hence reduction='none';
    # otherwise use the default mean. Replaces the deprecated `reduce=` kwarg.
    loss = nn.NLLLoss(
        reduction='none' if config.use_bootstrap_loss else 'mean',
        weight=torch.FloatTensor(config.class_weights))
    softmax = nn.Softmax(dim=1)
    log_softmax = nn.LogSoftmax(dim=1)

    if config.use_gpu:
        model = model.cuda()
        loss = loss.cuda()
        softmax = softmax.cuda()
        log_softmax = log_softmax.cuda()

    # Two parameter groups: the GNN gets its own initial learning rate.
    optimizer = torch.optim.Adam([{
        'params': model.decoder.parameters()
    }, {
        'params': model.gnn.parameters(),
        'lr': config.gnn_initial_lr
    }],
                                 lr=config.base_initial_lr,
                                 betas=config.betas,
                                 eps=config.eps,
                                 weight_decay=config.weight_decay)

    if config.lr_schedule_type == 'exp':

        def lambda_1(lambda_epoch):
            # Polynomial decay of the LR multiplier over the training run.
            return pow((1 - ((lambda_epoch - 1) / num_epochs)),
                       config.lr_decay)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lambda_1)
    elif config.lr_schedule_type == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=config.lr_decay, patience=config.lr_patience)
    else:
        logger.error('Bad scheduler')
        sys.exit(1)  # sys.exit, not the interactive-only exit() builtin

    params = sum(
        np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    logger.info("Number of trainable parameters: %d", params)

    def compute_loss(output, target):
        """Return the scalar training loss for one batch.

        With bootstrap loss enabled, keep only the hardest
        ``crop_size**2 * bootstrap_rate`` pixels per image and average those;
        otherwise the criterion already returns the weighted mean.
        (Previously duplicated verbatim in eval_set and the training loop.)
        """
        if config.use_bootstrap_loss:
            loss_per_pixel = loss.forward(log_softmax(output.float()), target)
            topk, _ = torch.topk(
                loss_per_pixel.view(output.size()[0], -1),
                int((config.crop_size**2) * config.bootstrap_rate))
            return torch.mean(topk)
        return loss.forward(log_softmax(output.float()), target)

    def get_current_learning_rates():
        # [decoder LR, GNN LR] in param-group order.
        return [param_group['lr'] for param_group in optimizer.param_groups]

    def eval_set(dataloader):
        """Evaluate the model on `dataloader`.

        Returns:
            (mean loss, per-class IoU array, nclasses x nclasses confusion
            matrix). Row/column 0 (unknown class) is zeroed before IoU.
        """
        model.eval()

        with torch.no_grad():
            loss_sum = 0.0
            # Flat accumulator for the (nclasses x nclasses) confusion matrix.
            init_tensor_value = np.zeros(nclasses**2)
            if config.use_gpu:
                confusion_matrix = torch.cuda.FloatTensor(init_tensor_value)
            else:
                confusion_matrix = torch.FloatTensor(init_tensor_value)

            start_time = time.time()

            for batch_idx, rgbd_label_xy in progress(enumerate(dataloader),
                                                     total=len(dataloader),
                                                     desc='Eval set'):
                # Batch layout: (rgb_hha, label, xy) — assumed NHWC; verified
                # only by the permutes below.
                x = rgbd_label_xy[0]
                xy = rgbd_label_xy[2]
                target = rgbd_label_xy[1].long()
                x = x.float()
                xy = xy.float()

                # NHWC -> NCHW for the network.
                input = x.permute(0, 3, 1, 2).contiguous()
                xy = xy.permute(0, 3, 1, 2).contiguous()
                if config.use_gpu:
                    input = input.cuda()
                    xy = xy.cuda()
                    target = target.cuda()

                output = model(input,
                               gnn_iterations=config.gnn_iterations,
                               k=config.gnn_k,
                               xy=xy,
                               use_gnn=config.use_gnn)

                loss_sum += compute_loss(output, target).item()

                pred = output.permute(0, 2, 3, 1).contiguous()
                pred = pred.view(-1, nclasses)
                pred = softmax(pred)
                pred_max_val, pred_arg_max = pred.max(1)

                # Encode each (target, prediction) pair as one flat index.
                pairs = target.view(-1) * nclasses + pred_arg_max.view(-1)
                for i in range(nclasses**2):
                    confusion_matrix[i] += pairs.eq(i).float().sum().item()

            sys.stdout.write(" - Eval time: {:.2f}s \n".format(time.time() -
                                                               start_time))
            loss_sum /= len(dataloader)

            confusion_matrix = confusion_matrix.cpu().numpy().reshape(
                (nclasses, nclasses))
            class_iou = np.zeros(nclasses)
            # Drop class 0 (unknown) from the statistics entirely.
            confusion_matrix[0, :] = np.zeros(nclasses)
            confusion_matrix[:, 0] = np.zeros(nclasses)
            for i in range(1, nclasses):
                # IoU = TP / (TP + FP + FN).
                class_iou[i] = confusion_matrix[i, i] / (
                    np.sum(confusion_matrix[i, :]) +
                    np.sum(confusion_matrix[:, i]) - confusion_matrix[i, i])

        return loss_sum, class_iou, confusion_matrix

    # Training parameters.
    logger.info('Num_epochs: %s', num_epochs)
    interval_to_show = 100  # batches between running-average loss reports

    train_losses = []
    eval_losses = []

    if pre_train_model:
        logger.info('Loading pre-train model %s... ', pre_train_model)
        model.load_state_dict(torch.load(pre_train_model))
    else:
        logger.info('Starting training from scratch...')

    # Training loop.
    for epoch in progress(range(start_epoch, num_epochs + 1), desc='Training'):
        batch_loss_avg = 0
        if config.lr_schedule_type == 'exp':
            # Explicit epoch keeps the decay consistent when resuming.
            # NOTE(review): step(epoch) is deprecated in recent PyTorch but
            # preserved here for identical behavior.
            scheduler.step(epoch)
        for batch_idx, rgbd_label_xy in progress(enumerate(dataloader_tr),
                                                 total=len(dataloader_tr),
                                                 desc=f'Epoch {epoch}'):
            x = rgbd_label_xy[0]
            target = rgbd_label_xy[1].long()
            xy = rgbd_label_xy[2]
            x = x.float()
            xy = xy.float()

            # NHWC -> NCHW; x is already float, so no extra cast is needed.
            input = x.permute(0, 3, 1, 2).contiguous()

            if config.use_gpu:
                input = input.cuda()
                xy = xy.cuda()
                target = target.cuda()

            xy = xy.permute(0, 3, 1, 2).contiguous()

            optimizer.zero_grad()
            model.train()

            output = model(input,
                           gnn_iterations=config.gnn_iterations,
                           k=config.gnn_k,
                           xy=xy,
                           use_gnn=config.use_gnn)

            loss_ = compute_loss(output, target)

            loss_.backward()
            optimizer.step()

            batch_loss_avg += loss_.item()

            if batch_idx % interval_to_show == 0 and batch_idx > 0:
                batch_loss_avg /= interval_to_show
                train_losses.append(batch_loss_avg)
                logger.info("E%dB%d Batch loss average: %s", epoch, batch_idx,
                            batch_loss_avg)
                print('\rEpoch:{}, Batch:{}, loss average:{}'.format(
                    epoch, batch_idx, batch_loss_avg))
                batch_loss_avg = 0

        batch_idx = len(dataloader_tr)
        logger.info("E%dB%d Saving model...", epoch, batch_idx)

        torch.save(
            model.state_dict(),
            os.path.join(
                save_models_dir,
                f'{check_point_prefix}{CHECK_POINT_SEP}{epoch!s}{MODELS_EXT}'))

        # Per-epoch evaluation.
        eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
        eval_losses.append(eval_loss)

        if config.lr_schedule_type == 'plateau':
            scheduler.step(eval_loss)
        print('Learning ...')
        logger.info("E%dB%d Def learning rate: %s", epoch, batch_idx,
                    get_current_learning_rates()[0])
        print('Epoch{} Def learning rate: {}'.format(
            epoch,
            get_current_learning_rates()[0]))
        logger.info("E%dB%d GNN learning rate: %s", epoch, batch_idx,
                    get_current_learning_rates()[1])
        print('Epoch{} GNN learning rate: {}'.format(
            epoch,
            get_current_learning_rates()[1]))
        logger.info("E%dB%d Eval loss: %s", epoch, batch_idx, eval_loss)
        print('Epoch{} Eval loss: {}'.format(epoch, eval_loss))
        logger.info("E%dB%d Class IoU:", epoch, batch_idx)
        print('Epoch{} Class IoU:'.format(epoch))
        for cl in range(nclasses):
            logger.info("%+10s: %-10s", IDX_LABEL[cl], class_iou[cl])
            print('{}:{}'.format(IDX_LABEL[cl], class_iou[cl]))
        logger.info("Mean IoU: %s", np.mean(class_iou[1:]))
        print("Mean IoU: %.2f" % np.mean(class_iou[1:]))
        logger.info("E%dB%d Confusion matrix:", epoch, batch_idx)
        logger.info(confusion_matrix)

    logger.info('Finished training!')
    logger.info('Saving trained model...')
    torch.save(model.state_dict(),
               os.path.join(save_models_dir, f'finish{MODELS_EXT}'))
    # Final evaluation on the validation set.
    eval_loss, class_iou, confusion_matrix = eval_set(dataloader_va)
    logger.info('Eval loss: %s', eval_loss)
    logger.info('Class IoU:')
    for cl in range(nclasses):
        logger.info("%+10s: %-10s", IDX_LABEL[cl], class_iou[cl])
    logger.info('Mean IoU: %s', np.mean(class_iou[1:]))