Example #1
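# Assumed context from the surrounding training script (not shown in this
# snippet): torch, torch.nn as nn, time, DataLoader from torch.utils.data,
# transforms from torchvision, make_grid from torchvision.utils, LambdaLR from
# torch.optim.lr_scheduler, SummaryWriter (TensorBoard), the repo modules
# ACNet_data, ACNet_models_V1 and utils, the helpers load_ckpt/save_ckpt/
# print_log, and the globals args, device, image_h, image_w and nyuv2_frq.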
def train():
    train_data = ACNet_data.SUNRGBD(transform=transforms.Compose([ACNet_data.scaleNorm(),
                                                                   ACNet_data.RandomScale((1.0, 1.4)),
                                                                   ACNet_data.RandomHSV((0.9, 1.1),
                                                                                         (0.9, 1.1),
                                                                                         (25, 25)),
                                                                   ACNet_data.RandomCrop(image_h, image_w),
                                                                   ACNet_data.RandomFlip(),
                                                                   ACNet_data.ToTensor(),
                                                                   ACNet_data.Normalize()]),
                                     phase_train=True,
                                     data_dir=args.data_dir)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False)

    num_train = len(train_data)

    # Load pretrained backbone weights only when training from scratch; when
    # resuming, the checkpoint loaded below overwrites them anyway.
    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    CEL_weighted = utils.CrossEntropyLoss2d(weight=nyuv2_frq)
    model.train()
    model.to(device)
    CEL_weighted.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    global_step = 0

    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device)

    lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)
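    # E.g. with lr_decay_rate=0.8 and lr_epoch_per_decay=100 (hypothetical
    # values), the base LR is multiplied by 0.8 after epoch 100, by 0.8**2
    # after epoch 200, and so on.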

    writer = SummaryWriter(args.summary_dir)

    for epoch in range(int(args.start_epoch), args.epochs):

        # Stepping with an explicit epoch index keeps the step-decay schedule
        # aligned after resuming from a checkpoint (this argument is deprecated
        # in newer PyTorch, where the scheduler state should be checkpointed
        # and restored instead, as in Example #2).
        scheduler.step(epoch)
        local_count = 0
        last_count = 0
        end_time = time.time()
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)

        for batch_idx, sample in enumerate(train_loader):

            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            # Ground truth at five resolutions for ACNet's multi-scale supervision
            target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']]
            optimizer.zero_grad()
            pred_scales = model(image, depth, args.checkpoint)  # args.checkpoint presumably toggles gradient checkpointing
            loss = CEL_weighted(pred_scales, target_scales)
            loss.backward()
            optimizer.step()
            local_count += image.shape[0]  # samples processed this epoch
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                time_inter = time.time() - end_time
                count_inter = local_count - last_count
                print_log(global_step, epoch, local_count, count_inter,
                          num_train, loss, time_inter)
                end_time = time.time()

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.detach().cpu().numpy(), global_step, bins='doane')
                grid_image = make_grid(image[:3].detach().cpu(), 3, normalize=True)
                writer.add_image('image', grid_image, global_step)
                grid_image = make_grid(depth[:3].detach().cpu(), 3, normalize=True)
                writer.add_image('depth', grid_image, global_step)
                grid_image = make_grid(utils.color_label(torch.argmax(pred_scales[0][:3], 1) + 1), 3, normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
                writer.add_scalar('CrossEntropyLoss', loss.item(), global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_last_lr()[0], global_step=global_step)
                last_count = local_count

    save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs,
              0, num_train)

    print("Training completed ")
Example #2
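# Assumed context as in Example #1, plus numpy as np, tqdm, and the
# freiburgforest_frq class weights; here load_ckpt/save_ckpt live in utils and
# also checkpoint the scheduler state.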
def train():
    train_dirs = [train_dir for train_dir in [args.train_dir, args.train_dir2] if train_dir is not None]
    train_data = ACNet_data.FreiburgForest(
        transform=transforms.Compose([
            ACNet_data.ScaleNorm(),
            # ACNet_data.RandomRotate((-13, 13)),
            # ACNet_data.RandomSkew((-0.05, 0.10)),
            ACNet_data.RandomScale((1.0, 1.4)),
            ACNet_data.RandomHSV((0.9, 1.1),
                                 (0.9, 1.1),
                                 (25, 25)),
            ACNet_data.RandomCrop(image_h, image_w),
            ACNet_data.RandomFlip(),
            ACNet_data.ToTensor(),
            ACNet_data.Normalize()
        ]),
        data_dirs=train_dirs,
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )

    valid_dirs = [valid_dir for valid_dir in [args.valid_dir, args.valid_dir2] if valid_dir is not None]
    valid_data = ACNet_data.FreiburgForest(
        transform=transforms.Compose([
            ACNet_data.ScaleNorm(),
            ACNet_data.ToTensor(),
            ACNet_data.Normalize()
        ]),
        data_dirs=valid_dirs,
        modal1_name=args.modal1,
        modal2_name=args.modal2,
    )

    '''
    # Split dataset into training and validation
    dataset_length = len(data)
    valid_split = 0.05  # tiny split due to the small size of the dataset
    valid_length = int(valid_split * dataset_length)
    train_length = dataset_length - valid_length
    train_data, valid_data = torch.utils.data.random_split(data, [train_length, valid_length])
    '''

    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False)
    # A larger batch is fine for validation since no gradients are stored
    valid_loader = DataLoader(valid_data, batch_size=args.batch_size * 3, shuffle=False,
                              num_workers=1, pin_memory=False)

    # Initialize model
    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=5, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=5, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    model.train()
    model.to(device)

    # Initialize criterion, optimizer and scheduler
    criterion = utils.CrossEntropyLoss2d(weight=freiburgforest_frq)
    criterion.to(device)

    # TODO: try with different optimizers and schedulers (CyclicLR exp_range for example)
    # TODO: try with a smaller LR (currently loss decay is too steep and then doesn't change)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay_lambda)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=int(np.ceil(args.epochs / 7)), T_mult=2, eta_min=8e-4)
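    # E.g. with args.epochs=100: T_0 = ceil(100/7) = 15, so the LR is annealed
    # from its base value towards eta_min and reset at epochs 15 and
    # 15 + 30 = 45, each period doubling through T_mult=2.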
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=5e-4)
    global_step = 0

    # TODO: add early stop to avoid overfitting

    # Continue training from previous checkpoint
    if args.last_ckpt:
        global_step, args.start_epoch = utils.load_ckpt(model, optimizer, scheduler, args.last_ckpt, device)

    writer = SummaryWriter(args.summary_dir)
    losses = []
    for epoch in tqdm(range(int(args.start_epoch), args.epochs)):

        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            utils.save_ckpt(args.ckpt_dir, model, optimizer, scheduler, global_step, epoch)

        for batch_idx, sample in enumerate(train_loader):
            modal1, modal2 = sample['modal1'].to(device), sample['modal2'].to(device)
            # Ground truth at five resolutions for ACNet's multi-scale supervision
            target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']]

            optimizer.zero_grad()
            pred_scales = model(modal1, modal2, args.checkpoint)
            loss = criterion(pred_scales, target_scales)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.detach().cpu().numpy(), global_step, bins='doane')

                grid_image = make_grid(modal1[:3].detach().cpu(), 3, normalize=False)
                writer.add_image('Modal1', grid_image, global_step)
                grid_image = make_grid(modal2[:3].detach().cpu(), 3, normalize=False)
                writer.add_image('Modal2', grid_image, global_step)
                grid_image = make_grid(utils.color_label(torch.argmax(pred_scales[0][:3], 1) + 1), 3, normalize=True,
                                       range=(0, 255))
                writer.add_image('Prediction', grid_image, global_step)
                grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=True, range=(0, 255))
                writer.add_image('GroundTruth', grid_image, global_step)
                writer.add_scalar('Loss', loss.item(), global_step=global_step)
                writer.add_scalar('Loss average', sum(losses) / len(losses), global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_last_lr()[0], global_step=global_step)

                # Compute validation metrics
                with torch.no_grad():
                    model.eval()

                    losses_val = []
                    acc_list = []
                    iou_list = []
                    for sample_val in valid_loader:
                        modal1_val, modal2_val = sample_val['modal1'].to(device), sample_val['modal2'].to(device)
                        target_val = sample_val['label'].to(device)
                        pred_val = model(modal1_val, modal2_val)  # eval-mode forward returns a single full-resolution prediction

                        losses_val.append(criterion([pred_val], [target_val]).item())
                        acc_list.append(utils.accuracy(
                            (torch.argmax(pred_val, 1) + 1).detach().cpu().numpy().astype(int),
                            target_val.detach().cpu().numpy().astype(int))[0])
                        iou_list.append(utils.compute_IoU(
                            y_pred=(torch.argmax(pred_val, 1) + 1).detach().cpu().numpy().astype(int),
                            y_true=target_val.detach().cpu().numpy().astype(int),
                            num_classes=5
                        ))

                    writer.add_scalar('Loss validation', sum(losses_val) / len(losses_val), global_step=global_step)
                    writer.add_scalar('Accuracy', sum(acc_list) / len(acc_list), global_step=global_step)
                    iou = np.mean(np.stack(iou_list, axis=0), axis=0)
                    writer.add_scalar('IoU_Road', iou[0], global_step=global_step)
                    writer.add_scalar('IoU_Grass', iou[1], global_step=global_step)
                    writer.add_scalar('IoU_Vegetation', iou[2], global_step=global_step)
                    writer.add_scalar('IoU_Sky', iou[3], global_step=global_step)
                    writer.add_scalar('IoU_Obstacle', iou[4], global_step=global_step)
                    writer.add_scalar('mIoU', np.mean(iou), global_step=global_step)

                    model.train()

                losses = []  # restart the running-average window for 'Loss average'

        scheduler.step()

    utils.save_ckpt(args.ckpt_dir, model, optimizer, scheduler, global_step, args.epochs)
    print("Training completed ")