Example #1
def main():
    args = get_args()
    log_folder = os.path.join('train_log', args.name)
    writer = SummaryWriter(log_folder)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # number of classes for each dataset.
    if args.dataset == 'PascalVOC':
        num_classes = 21
    elif args.dataset == 'COCO':
        num_classes = 81
    else:
        raise Exception("No dataset named {}.".format(args.dataset))

    # Select Model & Method
    model = models.__dict__[args.arch](num_classes=num_classes)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # Optimizer
    optimizer = torch.optim.SGD(
        [{
            'params': get_parameters(model, bias=False, final=False),
            'lr': args.lr,
            'weight_decay': args.wd
        }, {
            'params': get_parameters(model, bias=True, final=False),
            'lr': args.lr * 2,
            'weight_decay': 0
        }, {
            'params': get_parameters(model, bias=False, final=True),
            'lr': args.lr * 10,
            'weight_decay': args.wd
        }, {
            'params': get_parameters(model, bias=True, final=True),
            'lr': args.lr * 20,
            'weight_decay': 0
        }],
        momentum=args.momentum)

    if args.resume:
        model = load_model(model, args.resume)

    train_loader = data_loader(args)
    data_iter = iter(train_loader)
    train_t = tqdm(range(args.max_iter))
    model.train()
    for global_iter in train_t:
        try:
            images, target, gt_map = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader(args))
            images, target, gt_map = next(data_iter)

        if args.gpu is not None:
            images = images.cuda(args.gpu)
            gt_map = gt_map.cuda(args.gpu)
            target = target.cuda(args.gpu)

        output = model(images)

        fc8_SEC_softmax = softmax_layer(output)
        loss_s = seed_loss_layer(fc8_SEC_softmax, gt_map)
        loss_e = expand_loss_layer(fc8_SEC_softmax, target, num_classes - 1)
        fc8_SEC_CRF_log = crf_layer(output, images, iternum=10)
        loss_c = constrain_loss_layer(fc8_SEC_softmax, fc8_SEC_CRF_log)

        loss = loss_s + loss_e + loss_c

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # writer add_scalars
        writer.add_scalar('loss', loss, global_iter)
        writer.add_scalars('losses', {
            'loss_s': loss_s,
            'loss_e': loss_e,
            'loss_c': loss_c
        }, global_iter)

        with torch.no_grad():
            if global_iter % 10 == 0:
                # writer add_images (origin, output, gt)
                origin = images.clone().detach() + torch.tensor(
                    [123., 117., 107.]).reshape(1, 3, 1, 1).cuda(args.gpu)

                size = (100, 100)
                origin = F.interpolate(origin, size=size)
                origins = vutils.make_grid(origin,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True)

                outputs = F.interpolate(output, size=size)
                _, outputs = torch.max(outputs, dim=1)
                outputs = outputs.unsqueeze(1)
                outputs = vutils.make_grid(outputs,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True).float()

                gt_maps = F.interpolate(gt_map, size=size)
                _, gt_maps = torch.max(gt_maps, dim=1)
                gt_maps = gt_maps.unsqueeze(1)
                gt_maps = vutils.make_grid(gt_maps,
                                           nrow=15,
                                           padding=2,
                                           normalize=True,
                                           scale_each=True).float()

                # gt_maps = F.interpolate(gt_map.unsqueeze(1).float(), size=size)
                # gt_maps = vutils.make_grid(gt_maps, nrow=15, padding=2, normalize=True, scale_each=True).float()

                grid_image = torch.cat((origins, outputs, gt_maps), dim=1)
                writer.add_image(args.name, grid_image, global_iter)


        description = '[{0:4d}/{1:4d}] loss: {2} s: {3} e: {4} c: {5}'.\
            format(global_iter+1, args.max_iter, loss, loss_s, loss_e, loss_c)
        train_t.set_description(desc=description)

        # save snapshot
        if global_iter % args.snapshot == 0:
            save_checkpoint(model.state_dict(), log_folder,
                            'checkpoint_%d.pth.tar' % global_iter)

        # lr decay
        if global_iter % args.lr_decay == 0:
            args.lr = args.lr * 0.1
            optimizer = adjust_learning_rate(optimizer, args.lr)

    print("Training is over...")
    save_checkpoint(model.state_dict(), log_folder, 'last_checkpoint.pth.tar')
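Example #1 relies on project helpers that are not shown (`get_args`, `get_parameters`, `data_loader`, `save_checkpoint`, `adjust_learning_rate`). A minimal sketch of the last two, assuming `save_checkpoint` simply serializes a state dict and `adjust_learning_rate` propagates the already-decayed base learning rate while keeping the 1x/2x/10x/20x group ratios set up above:

import os
import torch


def save_checkpoint(state_dict, log_folder, filename):
    # Hypothetical helper: write the state dict to <log_folder>/<filename>.
    os.makedirs(log_folder, exist_ok=True)
    torch.save(state_dict, os.path.join(log_folder, filename))


def adjust_learning_rate(optimizer, lr):
    # Hypothetical helper: the caller has already multiplied lr by 0.1, so we
    # only re-apply the per-group scales used when the optimizer was built.
    scales = (1, 2, 10, 20)
    for group, scale in zip(optimizer.param_groups, scales):
        group['lr'] = lr * scale
    return optimizer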
Example #2
def test(epoch):
    global best_accuracy
    global best_miou
    model.eval()
    test_loss_iou = 0
    test_loss_vec = 0
    hist = np.zeros((config["task1_classes"], config["task1_classes"]))
    hist_angles = np.zeros((config["task2_classes"], config["task2_classes"]))
    crop_size = config["val_dataset"][args.dataset]["crop_size"]
    for i, (inputsBGR, labels, vecmap_angles) in enumerate(val_loader, 0):
        inputsBGR = Variable(inputsBGR.float().cuda(),
                             volatile=True,
                             requires_grad=False)

        outputs, pred_vecmaps = model(inputsBGR)
        if args.multi_scale_pred:
            loss1 = road_loss(outputs[0],
                              util.to_variable(labels[0], True, False), True)
            num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks
            for idx in range(num_stacks - 1):
                loss1 += road_loss(outputs[idx + 1],
                                   util.to_variable(labels[0], True, False),
                                   True)
            for idx, output in enumerate(outputs[-2:]):
                loss1 += road_loss(
                    output, util.to_variable(labels[idx + 1], True, False),
                    True)

            loss2 = angle_loss(pred_vecmaps[0],
                               util.to_variable(vecmap_angles[0], True, False))
            for idx in range(num_stacks - 1):
                loss2 += angle_loss(
                    pred_vecmaps[idx + 1],
                    util.to_variable(vecmap_angles[0], True, False))
            for idx, pred_vecmap in enumerate(pred_vecmaps[-2:]):
                loss2 += angle_loss(
                    pred_vecmap,
                    util.to_variable(vecmap_angles[idx + 1], True, False))

            outputs = outputs[-1]
            pred_vecmaps = pred_vecmaps[-1]
        else:
            loss1 = road_loss(outputs,
                              util.to_variable(labels[0], True, False), True)
            loss2 = angle_loss(pred_vecmaps,
                               util.to_variable(vecmap_angles[0], True, False))

        test_loss_iou += loss1.data[0]
        test_loss_vec += loss2.data[0]

        _, predicted = torch.max(outputs.data, 1)

        correctLabel = labels[-1].view(-1, crop_size, crop_size).long()
        hist += util.fast_hist(
            predicted.view(predicted.size(0), -1).cpu().numpy(),
            correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
            config["task1_classes"],
        )

        _, predicted_angle = torch.max(pred_vecmaps.data, 1)
        correct_angles = vecmap_angles[-1].view(-1, crop_size,
                                                crop_size).long()
        hist_angles += util.fast_hist(
            predicted_angle.view(predicted_angle.size(0), -1).cpu().numpy(),
            correct_angles.view(correct_angles.size(0), -1).cpu().numpy(),
            config["task2_classes"],
        )

        p_accu, miou, road_iou, fwacc = util.performMetrics(
            train_loss_file,
            val_loss_file,
            epoch,
            hist,
            test_loss_iou / (i + 1),
            test_loss_vec / (i + 1),
            is_train=False,
        )
        p_accu_angle, miou_angle, fwacc_angle = util.performAngleMetrics(
            train_loss_angle_file,
            val_loss_angle_file,
            epoch,
            hist_angles,
            is_train=False)

        viz_util.progress_bar(
            i,
            len(val_loader),
            "Loss: %.6f | VecLoss: %.6f | road miou: %.4f%%(%.4f%%) | angle miou: %.4f%%"
            % (
                test_loss_iou / (i + 1),
                test_loss_vec / (i + 1),
                miou,
                road_iou,
                miou_angle,
            ),
        )

        if i % 100 == 0 or i == len(val_loader) - 1:
            images_path = "{}/images/".format(experiment_dir)
            util.ensure_dir(images_path)
            util.savePredictedProb(
                inputsBGR.data.cpu(),
                labels[-1].cpu(),
                predicted.cpu(),
                F.softmax(outputs, dim=1).data.cpu()[:, 1, :, :],
                predicted_angle.cpu(),
                os.path.join(images_path,
                             "validate_pair_{}_{}.png".format(epoch, i)),
                norm_type=config["val_dataset"]["normalize_type"],
            )

        del inputsBGR, labels, predicted, outputs, pred_vecmaps, predicted_angle

    accuracy, miou, road_iou, fwacc = util.performMetrics(
        train_loss_file,
        val_loss_file,
        epoch,
        hist,
        test_loss_iou / len(val_loader),
        test_loss_vec / len(val_loader),
        is_train=False,
        write=True,
    )
    util.performAngleMetrics(
        train_loss_angle_file,
        val_loss_angle_file,
        epoch,
        hist_angles,
        is_train=False,
        write=True,
    )

    if miou > best_miou:
        best_accuracy = accuracy
        best_miou = miou
        util.save_checkpoint(epoch, test_loss_iou / len(val_loader), model,
                             optimizer, best_accuracy, best_miou, config,
                             experiment_dir)

    return test_loss_iou / len(val_loader)
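`util.fast_hist` and `util.performMetrics` are not shown; the histogram is the usual confusion matrix from which per-class IoU and mIoU are derived. A minimal sketch of that part, assuming labels outside [0, n) should be ignored:

import numpy as np


def fast_hist(pred, label, n):
    # n x n confusion matrix: rows index ground-truth classes, columns predictions.
    pred, label = pred.flatten(), label.flatten()
    valid = (label >= 0) & (label < n)
    return np.bincount(n * label[valid].astype(int) + pred[valid].astype(int),
                       minlength=n ** 2).reshape(n, n)


def per_class_iou(hist):
    # IoU per class; the mean of this vector is the mIoU reported above.
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-10)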
Example #3
            logger.info('epoch:{},{}/{}      learning_rate:{:}'.format(
                epoch, i_ter - epoch * len(train_loader), len(train_loader),
                lr))
            logger.info('loss:{:.2f}     accuracy:{}'.format(
                losses.val, batch_acc))
            logger.info('batch_time:{:.2f}, total_time:{:.2f}'.format(
                batch_time.val, batch_time.sum))
            logger.info('data_time:{:.2f}'.format(data_time.avg))
            print('out_shape:{}'.format(preds.shape))
            parsing_acc = AverageMeter()
# information for epoch
    epoch_time.update(time.time() - epoch_end)
    epoch_end = time.time()
    logger.info('*******')
    logger.info('')
    logger.info('epoch:{}   total_time:{:.2f}.'.format(epoch, epoch_time.sum))
    logger.info('train_loss:{:.2f}'.format(losses.avg))
    # save model for every epoch
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'model': config.MODEL.NAME,
            'state_dict': model.state_dict(),  # for resume
            'module_state_dict': model.module.state_dict(),  # for eval
            'optimizer': optimizer.state_dict(),
        },
        out_dir,
        filename='checkpoint_parallel_deconv.pth')
    logger.info('save checkpoint in {}'.format(out_dir))
logger.info('finish')
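The logging code above assumes `AverageMeter` objects (`losses`, `batch_time`, `data_time`, `epoch_time`) that expose `.val`, `.sum`, and `.avg`. A minimal sketch following the common PyTorch-examples pattern:

class AverageMeter:
    # Tracks the latest value plus a running sum, count and average.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count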
Example #4
            for item in loss_list:
                s += (item + '_loss:%.4f,' % step_vals[item])
            print(s[:-1])
            s = ''
            for item in acc_type_list:
                acc_record[item] = np.mean(
                    np.array([
                        modelopera.accuracy(algorithm, eval_loaders[i])
                        for i in eval_name_dict[item]
                    ]))
                s += (item + '_acc:%.4f,' % acc_record[item])
            print(s[:-1])
            if acc_record['valid'] > best_valid_acc:
                best_valid_acc = acc_record['valid']
                target_acc = acc_record['target']
            if args.save_model_every_checkpoint:
                save_checkpoint(f'model_epoch{epoch}.pkl', algorithm, args)
            print('total cost time: %.4f' % (time.time() - sss))
            algorithm_dict = algorithm.state_dict()

    save_checkpoint('model.pkl', algorithm, args)

    print('valid acc: %.4f' % best_valid_acc)
    print('DG result: %.4f' % target_acc)

    with open(os.path.join(args.output, 'done.txt'), 'w') as f:
        f.write('done\n')
        f.write('total cost time:%s\n' % (str(time.time() - sss)))
        f.write('valid acc:%.4f\n' % (best_valid_acc))
        f.write('target acc:%.4f' % (target_acc))
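`save_checkpoint(filename, algorithm, args)` is project-specific and only called here. A plausible sketch, assuming it just stores the algorithm's weights and the run configuration under `args.output`:

import os
import torch


def save_checkpoint(filename, algorithm, args):
    # Hypothetical helper: persist the model state dict and the run arguments.
    torch.save({'model_dict': algorithm.state_dict(), 'args': vars(args)},
               os.path.join(args.output, filename))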
Example #5
def main():
    args = parser.parse_args()

    # Create dataset
    print("=> creating dataset")

    device = torch.device('cuda')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = Talk2Car(root=args.root,
                             split='train',
                             transform=transforms.Compose(
                                 [transforms.ToTensor(), normalize]))
    val_dataset = Talk2Car(root=args.root,
                           split='val',
                           transform=transforms.Compose(
                               [transforms.ToTensor(), normalize]))

    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers,
                                       pin_memory=True,
                                       drop_last=True)
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=False)

    # Create model
    print("=> creating model")
    img_encoder = nn.DataParallel(
        EfficientNet.from_pretrained('efficientnet-b2'))
    text_encoder = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')
    fc_model = nn.Sequential(nn.Linear(1024, 1000), nn.ReLU(),
                             nn.Linear(1000, 1000))

    fc_model.to(device)
    img_encoder.to(device)
    text_encoder.to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.ignore_index,
                                    reduction='mean')
    criterion.to(device)

    cudnn.benchmark = True

    # Optimizer and scheduler
    print("=> creating optimizer and scheduler")
    params = list(img_encoder.parameters()) + list(fc_model.parameters())
    optimizer = optim.SGD(params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=args.nesterov)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # Checkpoint
    checkpoint = 'checkpoint.pth.tar'
    if os.path.exists(checkpoint):
        print("=> resume from checkpoint at %s" % (checkpoint))
        checkpoint = torch.load(checkpoint, map_location='cpu')
        img_encoder.load_state_dict(checkpoint['img_encoder'])
        fc_model.load_state_dict(checkpoint['fc_model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']
        best_ap50 = checkpoint['best_ap50']
    else:
        print("=> no checkpoint at %s" % (checkpoint))
        best_ap50 = 0
        start_epoch = 0

    # Start training
    print("=> start training")

    for epoch in range(start_epoch, args.epochs):
        print('Start epoch %d/%d' % (epoch, args.epochs))
        print(20 * '-')

        # Train
        train(train_dataloader, img_encoder, text_encoder, fc_model, optimizer,
              criterion, epoch, args)

        # Update lr rate
        scheduler.step()

        # Evaluate
        ap50 = evaluate(val_dataloader, img_encoder, text_encoder, fc_model,
                        args)
        print("AP50:", ap50)

        # Checkpoint
        if ap50 > best_ap50:
            new_best = True
            best_ap50 = ap50
        else:
            new_best = False

        save_checkpoint(
            {
                'img_encoder': img_encoder.state_dict(),
                'fc_model': fc_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch + 1,
                'best_ap50': best_ap50
            },
            new_best=new_best)

    # Evaluate
    if args.evaluate:
        print("=> Evaluating best model")
        checkpoint = torch.load('best_model.pth.tar', map_location='cpu')
        img_encoder.load_state_dict(checkpoint['img_encoder'])
        fc_model.load_state_dict(checkpoint['fc_model'])
        ap50 = evaluate(val_dataloader, img_encoder, text_encoder, fc_model,
                        args)
        print('AP50 on validation set is %.2f' % (ap50 * 100))
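The resume/evaluate paths above imply a `save_checkpoint(state, new_best=...)` helper that overwrites `checkpoint.pth.tar` every epoch and copies it to `best_model.pth.tar` on a new best. A sketch under that assumption:

import shutil
import torch


def save_checkpoint(state, new_best=False,
                    filename='checkpoint.pth.tar',
                    best_filename='best_model.pth.tar'):
    # Always refresh the rolling checkpoint; duplicate it when AP50 improved.
    torch.save(state, filename)
    if new_best:
        shutil.copyfile(filename, best_filename)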
Example #6
    print('\nEvaluation:')
    print(
        '\tVal Acc of Multi-view: %.2f - Val Acc of Single view: %.2f - Loss: %.4f'
        % (avg_test_acc.item(), avg_test_acc_single.item(), avg_loss.item()))
    print('\tCurrent best val acc: %.2f' % best_acc)

    # Log epoch to tensorboard
    # See log using: tensorboard --logdir='logs' --port=6006
    logEpoch(logger, model, epoch + 1, avg_loss, avg_test_acc)

    # Save model
    if avg_test_acc > best_acc:
        print('\tSaving checkpoint - Acc: %.2f' % avg_test_acc)
        best_acc = avg_test_acc
        best_loss = avg_loss
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': avg_test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, args.model)

    # Decaying Learning Rate
    if (epoch + 1) % args.lr_decay_freq == 0:
        lr *= args.lr_decay
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        print('Learning rate:', lr)
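Note that rebuilding the Adam optimizer at every decay step discards its running moment estimates. An alternative (not what this example does) is to rescale the learning rate of the existing optimizer in place, for instance with a small helper like the sketch below, called as `lr = decay_learning_rate(optimizer, lr, args.lr_decay)`:

def decay_learning_rate(optimizer, lr, decay):
    # Multiply the learning rate of every parameter group in place; the
    # optimizer's internal state (e.g. Adam moments) is preserved.
    new_lr = lr * decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr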
Example #7
def train(args):
    epochs = 350
    batch_size = 288
    util.set_seeds(args.rank)
    model = nn.EfficientNet().cuda()
    lr = batch_size * torch.cuda.device_count() * 0.256 / 4096
    optimizer = nn.RMSprop(util.add_weight_decay(model), lr, 0.9, 1e-3, momentum=0.9)
    ema = nn.EMA(model)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    else:
        model = torch.nn.DataParallel(model)
    criterion = nn.CrossEntropyLoss().cuda()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                   transforms.Compose([util.RandomResize(),
                                                       transforms.ColorJitter(0.4, 0.4, 0.4),
                                                       transforms.RandomHorizontalFlip(),
                                                       util.RandomAugment(),
                                                       transforms.ToTensor(), normalize]))
    if args.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        sampler = None

    loader = data.DataLoader(dataset, batch_size, sampler=sampler, num_workers=8, pin_memory=True)

    scheduler = nn.StepLR(optimizer)
    amp_scale = torch.cuda.amp.GradScaler()
    with open(f'weights/{scheduler.__str__()}.csv', 'w') as f:
        if args.local_rank == 0:
            writer = csv.DictWriter(f, fieldnames=['epoch', 'acc@1', 'acc@5'])
            writer.writeheader()
        best_acc1 = 0
        for epoch in range(0, epochs):
            if args.distributed:
                sampler.set_epoch(epoch)
            if args.local_rank == 0:
                print(('\n' + '%10s' * 2) % ('epoch', 'loss'))
                bar = tqdm.tqdm(loader, total=len(loader))
            else:
                bar = loader
            model.train()
            for images, target in bar:
                loss = batch(images, target, model, criterion)
                optimizer.zero_grad()
                amp_scale.scale(loss).backward()
                amp_scale.step(optimizer)
                amp_scale.update()

                ema.update(model)
                torch.cuda.synchronize()
                if args.local_rank == 0:
                    bar.set_description(('%10s' + '%10.4g') % ('%g/%g' % (epoch + 1, epochs), loss))

            scheduler.step(epoch + 1)
            if args.local_rank == 0:
                acc1, acc5 = test(ema.model.eval())
                writer.writerow({'acc@1': str(f'{acc1:.3f}'),
                                 'acc@5': str(f'{acc5:.3f}'),
                                 'epoch': str(epoch + 1).zfill(3)})
                util.save_checkpoint({'state_dict': ema.model.state_dict()}, acc1 > best_acc1)
                best_acc1 = max(acc1, best_acc1)
    if args.distributed:
        torch.distributed.destroy_process_group()
    torch.cuda.empty_cache()
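`nn` in Example #7 is a project module, not `torch.nn`; in particular `nn.EMA` keeps an exponential moving average of the weights that is later evaluated instead of the raw model. A minimal sketch of that pattern (the decay value and the `module.` prefix handling are assumptions):

import copy
import torch


class EMA:
    # Shadow copy of the model whose parameters follow an exponential
    # moving average of the live model's parameters.
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.model = copy.deepcopy(model).eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        ema_state = self.model.state_dict()
        for name, value in model.state_dict().items():
            name = name.replace('module.', '', 1)  # tolerate DataParallel/DDP wrappers
            if value.dtype.is_floating_point:
                ema_state[name].mul_(self.decay).add_(value, alpha=1 - self.decay)
            else:
                ema_state[name].copy_(value)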
Example #8
def train(gpu, ngpus_per_node, argss):
    global args
    args = argss

    if args.arch == 'triple':
        model = TriSeNet(layers=args.layers, classes=args.classes)
        modules_ori = [
            model.layer0, model.layer1, model.layer2, model.layer3,
            model.layer4
        ]
        # modules_new = [model.down_8_32, model.sa_8_32, model.seg_head]
        modules_new = []
        for key, value in model._modules.items():
            if "layer" not in key:
                modules_new.append(value)
    args.index_split = len(
        modules_ori)  # modules after index_split get a 10x learning rate
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    print("=> creating model ...")
    print("Classes: {}".format(args.classes))
    print(model)
    model = torch.nn.DataParallel(model.cuda())
    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(gpu)

    value_scale = 255
    ## RGB mean & std
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=mean,
                       ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)
    ])

    train_data = SemData(split='train',
                         data_root=args.data_root,
                         data_list=args.train_list,
                         transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, \
        shuffle=True, num_workers=args.workers, pin_memory=True, \
        sampler=None, drop_last=True)

    # Training Loop
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = args.max_iter

    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            input, target = next(data_iter)
            if not input.size(0) == args.batch_size:
                raise StopIteration
        except StopIteration:
            epoch += 1
            data_iter = iter(train_loader)
            input, target = next(data_iter)
            # need to update the AverageMeter for new epoch
            main_loss_meter = AverageMeter()
            aux_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # measure data loading time
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        main_out = model(input)
        main_loss = criterion(main_out, target)
        aux_loss = torch.tensor(0).cuda()
        loss = main_loss + args.aux_weight * aux_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n = input.size(0)
        main_out = main_out.detach().max(1)[1]
        intersection, union, target = intersectionAndUnionGPU(
            main_out, target, args.classes, args.ignore_label)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        target = target.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Using Poly strategy to change the learning rate
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        # args.index_split = 5 -> ResNet has 5 stages
        for index in range(0, args.index_split):
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        iter_log = current_iter + 1

        if iter_log % args.print_freq == 0:
            print('Iteration: [{}/{}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'ETA {remain_time} '
                  'MainLoss {main_loss_meter.val:.4f} '
                  'AuxLoss {aux_loss_meter.val:.4f} '
                  'Loss {loss_meter.val:.4f} '
                  'Accuracy {accuracy:.4f}.'.format(
                      iter_log,
                      args.max_iter,
                      data_time=data_time,
                      batch_time=batch_time,
                      remain_time=remain_time,
                      main_loss_meter=main_loss_meter,
                      aux_loss_meter=aux_loss_meter,
                      loss_meter=loss_meter,
                      accuracy=accuracy))
    save_checkpoint(
        {
            'iteration': iter_log,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, False, args.save_path)
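`poly_learning_rate` is the usual polynomial decay used in segmentation codebases, lr = base_lr * (1 - iter / max_iter) ** power. A minimal sketch:

def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    # "Poly" schedule: decays smoothly from base_lr to 0 over max_iter steps.
    return base_lr * (1 - float(curr_iter) / max_iter) ** power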
Example #9
def main_worker(local_rank, ngpus_per_node, argss):
    global args
    args = argss

    dist.init_process_group(backend=args.dist_backend)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers, classes=args.classes, zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        if not os.path.exists(args.save_path):
            os.mkdir(args.save_path)
    if args.arch == 'psp':
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'bise_v1':
        model = BiseNet(num_classes=args.classes)
        modules_ori = [model.sp, model.cp]
        modules_new = [model.ffm, model.conv_out, model.conv_out16, model.conv_out32]
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))
    args.index_split = 5
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path) # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)
    if args.distributed:
        torch.cuda.set_device(local_rank)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(teacher_model.cuda(), device_ids=[local_rank])

    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())
    
    if teacher_model is not None:
        checkpoint = torch.load(args.teacher_model_path, map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(args.teacher_model_path))
    
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda(local_rank)
    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(local_rank)
            
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> mp weight found at '{}'".format(args.weight))
    
    best_mIoU_val = 0.0
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['point']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))    
        
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
        
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])

    train_data = dataset.SemData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            # Use .set_epoch() method to reshuffle the dataset partition at every iteration
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(local_rank, train_loader, model, teacher_model, criterion, kd_criterion, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
        
        is_best = False
        if args.evaluate:
            loss_val, mIoU_val, mAcc_val, allAcc_val = validate(local_rank, val_loader, model, criterion)
            if main_process():
                writer.add_scalar('loss_val', loss_val, epoch_log)
                writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                writer.add_scalar('allAcc_val', allAcc_val, epoch_log)

                if best_mIoU_val < mIoU_val:
                    is_best = True
                    best_mIoU_val = mIoU_val
                    logger.info('==>The best val mIoU: %.3f' % (best_mIoU_val))

        
        if (epoch_log % args.save_freq == 0) and main_process():
            save_checkpoint(
                {
                    'epoch': epoch_log, 
                    'state_dict': model.state_dict(), 
                    'optimizer': optimizer.state_dict(),
                    'best_mIoU_val': best_mIoU_val
                }, 
                is_best, 
                args.save_path
            )
            if is_best:
                logger.info('Saving checkpoint to:' + args.save_path + '/best.pth with mIoU: ' + str(best_mIoU_val) )
            else:
                logger.info('Saving checkpoint to:' + args.save_path + '/last.pth with mIoU: ' + str(mIoU_val) )

    if main_process():  
        writer.close()  # close the writer, otherwise an EOFError can appear
        logger.info('==>Training done!\nBest mIoU: %.3f' % (best_mIoU_val))
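`KDLoss` is only instantiated above, not defined. A common formulation for segmentation distillation is temperature-scaled KL divergence between student and teacher logits, averaged over non-ignored pixels; the sketch below follows that pattern and is an assumption about, not a copy of, this repository's criterion:

import torch
import torch.nn as nn
import torch.nn.functional as F


class KDLoss(nn.Module):
    # Pixel-wise knowledge distillation: KL(teacher || student) on
    # temperature-softened logits, skipping pixels labelled ignore_index.
    def __init__(self, ignore_index=255, temperature=1.0):
        super().__init__()
        self.ignore_index = ignore_index
        self.T = temperature

    def forward(self, student_logits, teacher_logits, target):
        # student_logits, teacher_logits: (N, C, H, W); target: (N, H, W)
        log_p_s = F.log_softmax(student_logits / self.T, dim=1)
        p_t = F.softmax(teacher_logits / self.T, dim=1)
        kd = F.kl_div(log_p_s, p_t, reduction='none').sum(dim=1)
        mask = (target != self.ignore_index).float()
        return (kd * mask).sum() / mask.sum().clamp(min=1) * self.T ** 2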
Example #10
def main():
    global args, best_EPE
    args = parser.parse_args()
    save_path = '{},{},b{},lr{}'.format(args.arch, args.solver,
                                        args.batch_size, args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    train_writer = SummaryWriter(os.path.join(save_path, 'train'))
    test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    output_writers = []
    for i in range(3):
        output_writers.append(
            SummaryWriter(os.path.join(save_path, 'test', str(i))))

    # Data loading code
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    if 'KITTI' in args.dataset:
        args.sparse = True
    if args.sparse:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            flow_transforms.RandomVerticalFlip(),
            flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split_file if args.split_file else args.split_value)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             shuffle=False)

    # create model
    if args.pretrained:
        network_data = torch.load(args.pretrained)
        args.arch = network_data['arch']
        print("=> using pre-trained model '{}'".format(args.pretrained))
    else:
        network_data = None
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](network_data)
    bias_params = model.bias_parameters()
    weight_params = model.weight_parameters()
    parallel = False

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
            parallel = True
        model = model.to(device)
        cudnn.benchmark = True

        if parallel is True:
            bias_params = model.module.bias_parameters()
            weight_params = model.module.weight_parameters()

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    param_groups = [{
        'params': bias_params,
        'weight_decay': args.bias_decay
    }, {
        'params': weight_params,
        'weight_decay': args.weight_decay
    }]
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(param_groups,
                                     args.lr,
                                     betas=(args.momentum, args.beta))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(param_groups,
                                    args.lr,
                                    momentum=args.momentum)

    if args.evaluate:
        best_EPE = validate(val_loader, model, 0, output_writers)
        return

    epoch_size = len(train_loader)
    # device_num = torch.cuda.device_count() if torch.cuda.is_available() else 1
    # epochs = 70000 // device_num // epoch_size + 1 if args.epochs == -1 else args.epochs
    epochs = 70000 // epoch_size + 1 if args.epochs == -1 else args.epochs

    for epoch in range(args.start_epoch, epochs):
        adjust_learning_rate(args, optimizer, epoch, epoch_size)

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, optimizer, epoch,
                                      epoch_size, train_writer)
        train_writer.add_scalar('mean EPE', train_EPE, epoch)

        # evaluate on validation set
        with torch.no_grad():
            EPE = validate(val_loader, model, epoch, output_writers)
        test_writer.add_scalar('mean EPE', EPE, epoch)

        if best_EPE < 0:
            best_EPE = EPE

        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
                'div_flow': args.div_flow
            }, is_best, save_path)
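One caveat in Example #10: the final `save_checkpoint` reads `model.module.state_dict()`, which only exists when the model was wrapped in `DataParallel`; on a single GPU (or CPU) it would raise `AttributeError`. A small hedge that works in both cases is a helper along these lines:

import torch


def unwrapped_state_dict(model):
    # Return the underlying module's state dict whether or not the model is
    # wrapped in DataParallel / DistributedDataParallel.
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        return model.module.state_dict()
    return model.state_dict()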
Example #11
def main():
    args = get_args()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # number of classes for each dataset.
    if args.dataset == 'PascalVOC':
        num_classes = 20
    else:
        raise Exception("No dataset named {}.".format(args.dataset))

    # Select Model & Method
    model = models.__dict__[args.arch](pretrained=args.pretrained,
                                       num_classes=num_classes)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # define loss function (criterion) and optimizer
    criterion = nn.MultiLabelSoftMarginLoss().cuda(args.gpu)
    # criterion = nn.BCEWithLogitsLoss().cuda(args.gpu)

    # Take apart parameters to give different Learning Rate
    param_features = []
    param_classifiers = []

    if args.arch.startswith('vgg'):
        for name, parameter in model.named_parameters():
            if 'features.' in name:
                param_features.append(parameter)
            else:
                param_classifiers.append(parameter)
    elif args.arch.startswith('resnet'):
        for name, parameter in model.named_parameters():
            if 'layer4.' in name or 'fc.' in name:
                param_classifiers.append(parameter)
            else:
                param_features.append(parameter)
    else:
        raise Exception("Fail to recognize the architecture")

    # Optimizer
    optimizer = torch.optim.SGD([
        {'params': param_features, 'lr': args.lr},
        {'params': param_classifiers, 'lr': args.lr * args.lr_ratio}],
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nest)

    # optionally resume from a checkpoint
    if args.resume:
        model, optimizer = load_model(model, optimizer, args)
    train_loader, val_loader, test_loader = data_loader(args)

    saving_dir = os.path.join(args.log_folder, args.name)

    if args.evaluate:
        # test_ap, test_loss = evaluate_cam(val_loader, model, criterion, args)
        # test_ap, test_loss = evaluate_cam2(val_loader, model, criterion, args)
        test_ap, test_loss = evaluate_cam3(val_loader, model, criterion, args)
        print_progress(test_ap, test_loss, 0, 0, prefix='test')
        return

    # Training Phase
    best_m_ap = 0
    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch, args)

        # Train for one epoch
        train_ap, train_loss = \
            train(train_loader, model, criterion, optimizer, epoch, args)
        print_progress(train_ap, train_loss, epoch+1, args.epochs)

        # Evaluate classification
        val_ap, val_loss = validate(val_loader, model, criterion, epoch, args)
        print_progress(val_ap, val_loss, epoch+1, args.epochs, prefix='validation')

        # # Save checkpoint at best performance:
        is_best = val_ap.mean() > best_m_ap
        if is_best:
            best_m_ap = val_ap.mean()

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_m_ap': best_m_ap,
            'optimizer': optimizer.state_dict(),
        }, is_best, saving_dir)

        save_progress(saving_dir, train_ap, train_loss, val_ap, val_loss, args)
Example #12
def test(epoch):
    global best_accuracy
    global best_miou
    model.eval()
    test_loss_iou = 0
    test_loss_vec = 0
    hist = np.zeros((config["task1_classes"], config["task1_classes"]))
    crop_size = config["val_dataset"][args.dataset]["crop_size"]
    for i, data in enumerate(val_loader, 0):
        inputs, labels, erased_label = data
        batch_size = inputs.size(0)

        inputs = Variable(inputs.float().cuda(),
                          volatile=True,
                          requires_grad=False)
        erased_label = Variable(erased_label[-1].float().cuda(),
                                volatile=True,
                                requires_grad=False).unsqueeze(dim=1)
        temp = erased_label

        for k in range(config['refinement']):
            in_ = torch.cat((inputs, erased_label, temp), dim=1)
            outputs = model(in_)
            if args.multi_scale_pred:
                loss1 = road_loss(outputs[0], labels[0].long().cuda(), False)
                num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks
                for idx in range(num_stacks - 1):
                    loss1 += road_loss(outputs[idx + 1],
                                       labels[0].long().cuda(), False)
                for idx, output in enumerate(outputs[-2:]):
                    loss1 += road_loss(output, labels[idx + 1].long().cuda(),
                                       False)

                outputs = outputs[-1]
            else:
                loss1 = road_loss(outputs, labels[-1].long().cuda(), False)

            temp = Variable(torch.max(outputs.data, 1)[1].float(),
                            volatile=True,
                            requires_grad=False).unsqueeze(dim=1)

        test_loss_iou += loss1.data[0]

        _, predicted = torch.max(outputs.data, 1)

        correctLabel = labels[-1].view(-1, crop_size, crop_size).long()
        hist += util.fast_hist(
            predicted.view(predicted.size(0), -1).cpu().numpy(),
            correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
            config["task1_classes"],
        )

        p_accu, miou, road_iou, fwacc = util.performMetrics(
            train_loss_file,
            val_loss_file,
            epoch,
            hist,
            test_loss_iou / (i + 1),
            0,
            is_train=False,
        )

        viz_util.progress_bar(
            i,
            len(val_loader),
            "Loss: %.6f | road miou: %.4f%%(%.4f%%)" % (
                test_loss_iou / (i + 1),
                miou,
                road_iou,
            ),
        )

        if i % 100 == 0 or i == len(val_loader) - 1:
            images_path = "{}/images/".format(experiment_dir)
            util.ensure_dir(images_path)
            util.savePredictedProb(
                inputs.data.cpu(),
                labels[-1].cpu(),
                predicted.cpu(),
                F.softmax(outputs, dim=1).data.cpu()[:, 1, :, :],
                None,
                os.path.join(images_path,
                             "validate_pair_{}_{}.png".format(epoch, i)),
                norm_type=config["val_dataset"]["normalize_type"],
            )

        del inputs, labels, predicted, outputs

    accuracy, miou, road_iou, fwacc = util.performMetrics(
        train_loss_file,
        val_loss_file,
        epoch,
        hist,
        test_loss_iou / len(val_loader),
        0,
        is_train=False,
        write=True,
    )

    if miou > best_miou:
        best_accuracy = accuracy
        best_miou = miou
        util.save_checkpoint(epoch, test_loss_iou / len(val_loader), model,
                             optimizer, best_accuracy, best_miou, config,
                             experiment_dir)

    return test_loss_iou / len(val_loader)
Example #13
def main():
    args = parser.parse_args()

    # Create dataset
    print("=> creating dataset")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = Talk2Car(root=args.root,
                             split='train',
                             transform=transforms.Compose(
                                 [transforms.ToTensor(), normalize]))
    val_dataset = Talk2Car(root=args.root,
                           split='val',
                           transform=transforms.Compose(
                               [transforms.ToTensor(), normalize]))

    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.workers,
                                       collate_fn=custom_collate,
                                       pin_memory=True,
                                       drop_last=True)
    val_dataloader = data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.workers,
                                     collate_fn=custom_collate,
                                     pin_memory=True,
                                     drop_last=False)

    # Create model
    print("=> creating model")
    img_encoder = resnet.__dict__['resnet18'](pretrained=True)
    text_encoder = nlp_models.TextEncoder(
        input_dim=train_dataset.number_of_words(),
        hidden_size=512,
        dropout=0.1)
    img_encoder.cuda()
    text_encoder.cuda()

    criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.ignore_index,
                                    reduction='mean')
    criterion.cuda()

    cudnn.benchmark = True

    # Optimizer and scheduler
    print("=> creating optimizer and scheduler")
    params = list(text_encoder.parameters()) + list(img_encoder.parameters())
    optimizer = optim.SGD(params,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay,
                          nesterov=args.nesterov)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # Checkpoint
    checkpoint = 'checkpoint.pth.tar'
    if os.path.exists(checkpoint):
        print("=> resume from checkpoint at %s" % (checkpoint))
        checkpoint = torch.load(checkpoint, map_location='cpu')
        img_encoder.load_state_dict(checkpoint['img_encoder'])
        text_encoder.load_state_dict(checkpoint['text_encoder'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']
        best_ap50 = checkpoint['best_ap50']
    else:
        print("=> no checkpoint at %s" % (checkpoint))
        best_ap50 = 0
        start_epoch = 0

    # Start training
    print("=> start training")

    for epoch in range(start_epoch, args.epochs):
        print('Start epoch %d/%d' % (epoch, args.epochs))
        print(20 * '-')

        # Train
        train(train_dataloader, img_encoder, text_encoder, optimizer,
              criterion, epoch, args)

        # Update lr rate
        scheduler.step()

        # Evaluate
        ap50 = evaluate(val_dataloader, img_encoder, text_encoder, args)

        # Checkpoint
        if ap50 > best_ap50:
            new_best = True
            best_ap50 = ap50
        else:
            new_best = False

        save_checkpoint(
            {
                'img_encoder': img_encoder.state_dict(),
                'text_encoder': text_encoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch + 1,
                'best_ap50': best_ap50
            },
            new_best=new_best)

    # Evaluate
    if args.evaluate:
        print("=> Evaluating best model")
        checkpoint = torch.load('best_model.pth.tar', map_location='cpu')
        img_encoder.load_state_dict(checkpoint['img_encoder'])
        text_encoder.load_state_dict(checkpoint['text_encoder'])
        ap50 = evaluate(val_dataloader, img_encoder, text_encoder, args)
        print('AP50 on validation set is %.2f' % (ap50 * 100))
Example #14
def main():
    epochs = 450
    device = torch.device('cuda')
    data_dir = '../Dataset/IMAGENET'
    num_gpu = torch.cuda.device_count()
    v_batch_size = 16 * num_gpu
    t_batch_size = 256 * num_gpu

    model = nn.EfficientNet(num_class, version[0], version[1],
                            version[3]).to(device)
    optimizer = nn.RMSprop(util.add_weight_decay(model),
                           0.012 * num_gpu,
                           0.9,
                           1e-3,
                           momentum=0.9)

    model = torch.nn.DataParallel(model)
    _ = model(torch.zeros(1, 3, version[2], version[2]).to(device))

    ema = nn.EMA(model)
    t_criterion = nn.CrossEntropyLoss().to(device)
    v_criterion = torch.nn.CrossEntropyLoss().to(device)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    t_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'train'),
        transforms.Compose([
            util.RandomResize(version[2]),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))
    v_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'val'),
        transforms.Compose([
            transforms.Resize(version[2] + 32),
            transforms.CenterCrop(version[2]),
            transforms.ToTensor(), normalize
        ]))

    t_loader = data.DataLoader(t_dataset,
                               batch_size=t_batch_size,
                               shuffle=True,
                               num_workers=os.cpu_count(),
                               pin_memory=True)
    v_loader = data.DataLoader(v_dataset,
                               batch_size=v_batch_size,
                               shuffle=False,
                               num_workers=os.cpu_count(),
                               pin_memory=True)

    scheduler = nn.StepLR(optimizer)
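    # GradScaler provides loss scaling for mixed precision (autocast presumably lives inside batch_fn); it is disabled when CUDA is unavailable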
    amp_scale = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    with open(f'weights/{scheduler.__str__()}.csv', 'w') as summary:
        writer = csv.DictWriter(
            summary,
            fieldnames=['epoch', 't_loss', 'v_loss', 'acc@1', 'acc@5'])
        writer.writeheader()
        best_acc1 = 0
        for epoch in range(0, epochs):
            print(('\n' + '%10s' * 2) % ('epoch', 'loss'))
            t_bar = tqdm.tqdm(t_loader, total=len(t_loader))
            model.train()
            t_loss = util.AverageMeter()
            v_loss = util.AverageMeter()
            for images, target in t_bar:
                loss, _, _, _ = batch_fn(images, target, model, device,
                                         t_criterion)
                optimizer.zero_grad()
                amp_scale.scale(loss).backward()
                amp_scale.step(optimizer)
                amp_scale.update()

                ema.update(model)
                torch.cuda.synchronize()
                t_loss.update(loss.item(), images.size(0))

                t_bar.set_description(('%10s' + '%10.4g') %
                                      ('%g/%g' % (epoch + 1, epochs), loss))
            top1 = util.AverageMeter()
            top5 = util.AverageMeter()

            ema_model = ema.model.eval()
            with torch.no_grad():
                for images, target in tqdm.tqdm(v_loader, ('%10s' * 2) %
                                                ('acc@1', 'acc@5')):
                    loss, acc1, acc5, output = batch_fn(
                        images, target, ema_model, device, v_criterion, False)
                    torch.cuda.synchronize()
                    v_loss.update(loss.item(), output.size(0))
                    top1.update(acc1.item(), images.size(0))
                    top5.update(acc5.item(), images.size(0))
                acc1, acc5 = top1.avg, top5.avg
                print('%10.3g' * 2 % (acc1, acc5))

            scheduler.step(epoch + 1)
            writer.writerow({
                'epoch': epoch + 1,
                't_loss': f'{t_loss.avg:.4f}',
                'v_loss': f'{v_loss.avg:.4f}',
                'acc@1': f'{acc1:.3f}',
                'acc@5': f'{acc5:.3f}'
            })
            util.save_checkpoint({'state_dict': ema.model.state_dict()},
                                 acc1 > best_acc1)
            best_acc1 = max(acc1, best_acc1)
    torch.cuda.empty_cache()
Exemple #15
0
def main():
	global args, best_nmi, start_epoch
	
	os.makedirs('{}'.format(args.save_path), exist_ok=True)

	# logging configuration
	logger = create_logger('global_logger', log_file=os.path.join(args.save_path,'log.txt'))
	logger.info('{}'.format(args))
	logger.info('{}'.format(coeff))
	tb_logger = SummaryWriter(args.save_path)

	# Construct Networks (Encoder, dim_loss)
	model = models.__dict__[args.arch](args.num_classes).cuda()
	print("=> created encoder '{}'".format(args.arch))
	
	toy_input = torch.zeros([5, 3, args.input_size, args.input_size]).cuda()
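	# get_dim presumably probes the encoder with the toy batch to collect the feature dimensions DIM_Loss needs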
	arch_info = get_dim(model, toy_input, args.layers, args.c_layer)

	dim_loss = models.__dict__['DIM_Loss'](arch_info).cuda()

	# optimizer
	para_dict = itertools.chain(filter(lambda x: x.requires_grad, model.parameters()),
		  filter(lambda x: x.requires_grad, dim_loss.parameters()))
	optimizer = torch.optim.RMSprop(para_dict, lr=args.lr, alpha=0.9)

	# criterions
	crit_graph = nn.BCELoss().cuda()
	crit_label = WeightedBCE().cuda()
	crit_c = nn.CrossEntropyLoss().cuda()

	# optionally resume from a checkpoint
	if args.resume:
		logger.info("=> loading checkpoint '{}'".format(args.resume))
		start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume)

	# data loading
	dataset = McDataset(
		  args.root, 
		  args.source, 
		  transform=transforms.ToTensor())
	dataloader = torch.utils.data.DataLoader(
		  dataset, batch_size=args.large_bs,
		  num_workers=args.workers, pin_memory=True, shuffle=True)
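	# Keras-style ImageDataGenerator supplying random augmentations; it is handed to train() below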
	datagen = ImageDataGenerator(
		  rotation_range=20,
		  width_shift_range=0.18,
		  height_shift_range=0.18,
		  channel_shift_range=0.1,
		  horizontal_flip=True,
		  rescale=0.95,
		  zoom_range=[0.85,1.15])

	
	for epoch in range(start_epoch, args.epochs):
	
		end = time.time()

		# Evaluation
		nmi, acc, ari = test(dataloader, model, epoch, tb_logger)
	
		# saving checkpoint
		is_best_nmi = nmi > best_nmi
		best_nmi = max(nmi, best_nmi)
		save_checkpoint({
			  'epoch': epoch, 
			  'model': model.state_dict(), 
			  'dim_loss': dim_loss.state_dict(), 
			  'best_nmi': best_nmi,
			  'optimizer': optimizer.state_dict()}, 
			  is_best_nmi, args.save_path + '/ckpt') 

		# training
		train(dataloader, model, dim_loss, crit_label, crit_graph, crit_c, optimizer, epoch, datagen, tb_logger)
Exemple #16
0
def train(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    teacher_model = None
    if args.teacher_model_path:
        teacher_model = PSPNet(layers=args.teacher_layers,
                               classes=args.classes,
                               zoom_factor=args.zoom_factor)
        kd_path = 'alpha_' + str(args.alpha) + '_Temp_' + str(args.temperature)
        args.save_path = os.path.join(args.save_path, kd_path)
        if not os.path.exists(args.save_path):
            os.mkdir(args.save_path)
    if args.arch == 'dct':
        model = DCTNet(layers=args.layers, classes=args.classes, vec_dim=300)
        # modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        # modules_new = [model.cls, model.aux]  # DCT4
        modules_ori = [model.cp, model.sp, model.head]
        modules_new = []
        args.index_split = len(modules_ori)  # param groups after index_split use a 10x learning rate

    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(
            dict(params=module.parameters(), lr=args.base_lr * 10))
    # args.index_split = 5
    optimizer = torch.optim.SGD(params_list,
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if teacher_model is not None:
            teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                teacher_model)

    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)  # tensorboardX
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
        if teacher_model is not None:
            logger.info(teacher_model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int(
            (args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
                                                          device_ids=[gpu])
        if teacher_model is not None:
            teacher_model = torch.nn.parallel.DistributedDataParallel(
                teacher_model.cuda(), device_ids=[gpu])

    else:
        model = torch.nn.DataParallel(model.cuda())
        if teacher_model is not None:
            teacher_model = torch.nn.DataParallel(teacher_model.cuda())

    if teacher_model is not None:
        checkpoint = torch.load(
            args.teacher_model_path,
            map_location=lambda storage, loc: storage.cuda())
        teacher_model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loading teacher checkpoint '{}'".format(
            args.teacher_model_path))

    if args.use_ohem:
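        # OhemCELoss presumably mines hard pixels (confidence below thresh) instead of averaging over every pixel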
        criterion = OhemCELoss(thresh=0.7,
                               ignore_index=args.ignore_label).cuda(gpu)
    else:
        criterion = nn.CrossEntropyLoss(
            ignore_index=args.ignore_label).cuda(gpu)

    kd_criterion = None
    if teacher_model is not None:
        kd_criterion = KDLoss(ignore_index=args.ignore_label).cuda(gpu)

    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight: '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> mp weight found at '{}'".format(args.weight))

    best_mIoU_val = 0.0
    mIoU_val = 0.0  # pre-defined so the checkpoint log below does not fail when args.evaluate is off
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # Load all tensors onto GPU
            checkpoint = torch.load(
                args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_iter = checkpoint['iteration']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_mIoU_val = checkpoint['best_mIoU_val']
            if main_process():
                logger.info("=> loaded checkpoint '{}' (iteration {})".format(
                    args.resume, checkpoint['iteration']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(
                    args.resume))

    value_scale = 255
    ## RGB mean & std
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_mean = [item * value_scale for item in rgb_mean]
    rgb_std = [0.229, 0.224, 0.225]
    rgb_std = [item * value_scale for item in rgb_std]

    # DCT mean & std
    dct_mean = dct_mean_std.train_upscaled_static_mean
    dct_mean = [item * value_scale for item in dct_mean]
    dct_std = dct_mean_std.train_upscaled_static_std
    dct_std = [item * value_scale for item in dct_std]

    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max],
                             padding=rgb_mean,
                             ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w],
                       crop_type='rand',
                       padding=rgb_mean,
                       ignore_label=args.ignore_label),
        # transform.GetDctCoefficient(),
        transform.ToTensor(),
        transform.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    train_data = dataset.SemData(split='train',
                                 img_type='rgb',
                                 data_root=args.data_root,
                                 data_list=args.train_list,
                                 transform=train_transform)
    # train_transform = transform_rgbdct.Compose([
    #     transform_rgbdct.RandScale([args.scale_min, args.scale_max]),
    #     transform_rgbdct.RandRotate([args.rotate_min, args.rotate_max], padding=rgb_mean, ignore_label=args.ignore_label),
    #     transform_rgbdct.RandomGaussianBlur(),
    #     transform_rgbdct.RandomHorizontalFlip(),
    #     transform_rgbdct.Crop([args.train_h, args.train_w], crop_type='rand', padding=rgb_mean, ignore_label=args.ignore_label),
    #     transform_rgbdct.GetDctCoefficient(),
    #     transform_rgbdct.ToTensor(),
    #     transform_rgbdct.Normalize(mean_rgb=rgb_mean, mean_dct=dct_mean, std_rgb=rgb_std,  std_dct=dct_std)])
    # train_data = dataset.SemData(split='train', img_type='rgb&dct', data_root=args.data_root, data_list=args.train_list, transform=train_transform)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, \
        shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, \
        sampler=train_sampler, drop_last=True)
    if args.evaluate:
        # val_h = int(args.base_h * args.scale)
        # val_w = int(args.base_w * args.scale)
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w],
                           crop_type='center',
                           padding=rgb_mean,
                           ignore_label=args.ignore_label),
            # transform.Resize(size=(val_h, val_w)),
            # transform.GetDctCoefficient(),
            transform.ToTensor(),
            transform.Normalize(mean=rgb_mean, std=rgb_std)
        ])
        val_data = dataset.SemData(split='val',
                                   img_type='rgb',
                                   data_root=args.data_root,
                                   data_list=args.val_list,
                                   transform=val_transform)
        # val_transform = transform_rgbdct.Compose([
        #     transform_rgbdct.Crop([args.train_h, args.train_w], crop_type='center', padding=rgb_mean, ignore_label=args.ignore_label),
        #     # transform.Resize(size=(val_h, val_w)),
        #     transform_rgbdct.GetDctCoefficient(),
        #     transform_rgbdct.ToTensor(),
        #     transform_rgbdct.Normalize(mean_rgb=rgb_mean, mean_dct=dct_mean, std_rgb=rgb_std,  std_dct=dct_std)])
        # val_data = dataset.SemData(split='val', img_type='rgb&dct', data_root=args.data_root, data_list=args.val_list, transform=val_transform)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, \
            shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # Training Loop
    batch_time = AverageMeter()
    data_time = AverageMeter()
    main_loss_meter = AverageMeter()
    # aux_loss_meter = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    # switch to train mode
    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    end = time.time()
    max_iter = args.max_iter

    data_iter = iter(train_loader)
    epoch = 0
    for current_iter in range(args.start_iter, args.max_iter):
        try:
            input, target = next(data_iter)
            if not target.size(0) == args.batch_size:
                raise StopIteration
        except StopIteration:
            epoch += 1
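            # re-seed the DistributedSampler so each process draws a fresh shuffle for the new epoch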
            if args.distributed:
                train_sampler.set_epoch(epoch)
                if main_process():
                    logger.info('train_sampler.set_epoch({})'.format(epoch))
            data_iter = iter(train_loader)
            input, target = next(data_iter)
            # need to update the AverageMeter for new epoch
            main_loss_meter = AverageMeter()
            # aux_loss_meter = AverageMeter()
            loss_meter = AverageMeter()
            intersection_meter = AverageMeter()
            union_meter = AverageMeter()
            target_meter = AverageMeter()
        # measure data loading time
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        # input = [input[0].cuda(non_blocking=True), input[1].cuda(non_blocking=True)]
        target = target.cuda(non_blocking=True)

        # compute output
        # main_out, aux_out = model(input)
        main_out = model(input)
        # _, H, W = target.shape
        # main_out = F.interpolate(main_out, size=(H, W), mode='bilinear', align_corners=True)
        main_loss = criterion(main_out, target)
        # aux_loss = criterion(aux_out, target)

        if not args.multiprocessing_distributed:
            # main_loss, aux_loss = torch.mean(main_loss), torch.mean(aux_loss)
            main_loss = torch.mean(main_loss)
        # loss = main_loss + args.aux_weight * aux_loss
        loss = main_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n = target.size(0)
        # if args.multiprocessing_distributed:
        #     main_loss, aux_loss, loss = main_loss.detach() * n, aux_loss * n, loss * n  # not considering ignore pixels
        #     count = target.new_tensor([n], dtype=torch.long)
        #     dist.all_reduce(main_loss), dist.all_reduce(aux_loss), dist.all_reduce(loss), dist.all_reduce(count)
        #     n = count.item()
        #     main_loss, aux_loss, loss = main_loss / n, aux_loss / n, loss / n
        if args.multiprocessing_distributed:
            main_loss, loss = main_loss.detach() * n, loss * n  # not considering ignore pixels
            count = target.new_tensor([n], dtype=torch.long)
            dist.all_reduce(main_loss), dist.all_reduce(loss), dist.all_reduce(count)
            n = count.item()
            main_loss, loss = main_loss / n, loss / n

        main_out = main_out.detach().max(1)[1]
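        # per-batch intersection/union/target counts feed the running mIoU and accuracy meters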
        intersection, union, target = intersectionAndUnionGPU(
            main_out, target, args.classes, args.ignore_label)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection)
            dist.all_reduce(union)
            dist.all_reduce(target)
        intersection = intersection.cpu().numpy()
        union = union.cpu().numpy()
        target = target.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)

        accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        main_loss_meter.update(main_loss.item(), n)
        # aux_loss_meter.update(aux_loss.item(), n)
        loss_meter.update(loss.item(), n)
        batch_time.update(time.time() - end)
        end = time.time()

        # Using Poly strategy to change the learning rate
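        # assumed form: lr = base_lr * (1 - current_iter / max_iter) ** power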
        current_lr = poly_learning_rate(args.base_lr,
                                        current_iter,
                                        max_iter,
                                        power=args.power)
        for index in range(0, args.index_split):  # param groups before index_split keep the base lr; the rest use 10x
            optimizer.param_groups[index]['lr'] = current_lr
        for index in range(args.index_split, len(optimizer.param_groups)):
            optimizer.param_groups[index]['lr'] = current_lr * 10

        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(t_h), int(t_m),
                                                    int(t_s))

        iter_log = current_iter + 1
        if iter_log % args.print_freq == 0 and main_process():
            logger.info('Iter [{}/{}] '
                        'LR: {lr:.3e}, '
                        'ETA: {remain_time}, '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f}), '
                        'Batch: {batch_time.val:.3f} ({batch_time.avg:.3f}), '
                        'MainLoss: {main_loss_meter.val:.4f}, '
                        # 'AuxLoss: {aux_loss_meter.val:.4f}, '
                        'Loss: {loss_meter.val:.4f}, '
                        'Accuracy: {accuracy:.4f}.'.format(
                            iter_log,
                            args.max_iter,
                            lr=current_lr,
                            remain_time=remain_time,
                            data_time=data_time,
                            batch_time=batch_time,
                            main_loss_meter=main_loss_meter,
                            # aux_loss_meter=aux_loss_meter,
                            loss_meter=loss_meter,
                            accuracy=accuracy))
        if main_process():
            writer.add_scalar('loss_train_batch', main_loss_meter.val,
                              iter_log)
            writer.add_scalar('mIoU_train_batch',
                              np.mean(intersection / (union + 1e-10)),
                              iter_log)
            writer.add_scalar('mAcc_train_batch',
                              np.mean(intersection / (target + 1e-10)),
                              iter_log)
            writer.add_scalar('allAcc_train_batch', accuracy, iter_log)

        if iter_log % len(train_loader) == 0 or iter_log == max_iter:  # end of each epoch or the final iteration
            iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
            accuracy_class = intersection_meter.sum / (target_meter.sum +
                                                       1e-10)
            mIoU_train = np.mean(iou_class)
            mAcc_train = np.mean(accuracy_class)
            allAcc_train = sum(
                intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
            loss_train = main_loss_meter.avg
            if main_process():
                logger.info('Train result at iteration [{}/{}]: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'\
                    .format(iter_log, max_iter, mIoU_train, mAcc_train, allAcc_train))
                writer.add_scalar('loss_train', loss_train, iter_log)
                writer.add_scalar('mIoU_train', mIoU_train, iter_log)
                writer.add_scalar('mAcc_train', mAcc_train, iter_log)
                writer.add_scalar('allAcc_train', allAcc_train, iter_log)

        # if iter_log % args.save_freq == 0:
            is_best = False
            if args.evaluate:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(
                    val_loader, model, criterion)
                model.train()  # the mode change from eval() to train()
                if main_process():
                    writer.add_scalar('loss_val', loss_val, iter_log)
                    writer.add_scalar('mIoU_val', mIoU_val, iter_log)
                    writer.add_scalar('mAcc_val', mAcc_val, iter_log)
                    writer.add_scalar('allAcc_val', allAcc_val, iter_log)

                    if best_mIoU_val < mIoU_val:
                        is_best = True
                        best_mIoU_val = mIoU_val
                        logger.info('==>The best val mIoU: %.3f' %
                                    (best_mIoU_val))

            if main_process():
                save_checkpoint(
                    {
                        'iteration': iter_log,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_mIoU_val': best_mIoU_val
                    }, is_best, args.save_path)
                logger.info('Saving checkpoint to:{}/iter_{}.pth or last.pth with mIoU:{:.3f}'\
                    .format(args.save_path, iter_log, mIoU_val))
                if is_best:
                    logger.info('Saving checkpoint to:{}/best.pth with mIoU:{:.3f}'\
                        .format(args.save_path, best_mIoU_val))

    if main_process():
        writer.close()  # the writer must be closed, otherwise an EOFError can occur
        logger.info(
            '==>Training done! The best val mIoU during training: %.3f' %
            (best_mIoU_val))
Exemple #17
0
def train():
    config_print()
    print("SEED : {}".format(GLOBAL_SEED))
    os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_ids
    set_seed(GLOBAL_SEED)
    best_prec1 = 0.
    write_log = 'logs/%s' % config.dataset_tag + config.gpu_ids
    write_val_log = 'logs/val%s' % config.dataset_tag + config.gpu_ids
    write = SummaryWriter(log_dir=write_log)
    write_val = SummaryWriter(log_dir=write_val_log)
    data_config = getDatasetConfig(config.dataset_tag)

    #load dataset
    train_dataset = CustomDataset(data_config['train'],
                                  data_config['train_root'],
                                  True)  # txt file, train_root_dir, is_training
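    # _init_fn presumably seeds each DataLoader worker so augmentation stays reproducible under GLOBAL_SEED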
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              num_workers=config.workers,
                              pin_memory=True,
                              worker_init_fn=_init_fn)
    val_dataset = CustomDataset(data_config['val'], data_config['val_root'],
                                False)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True)  #,worker_init_fn=_init_fn)

    print('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.
          format(dataset_name=config.dataset_tag,
                 train_num=len(train_dataset),
                 val_num=len(val_dataset)))

    # define model

    net = init_model(pretrained=True,
                     model_name=config.model_name,
                     class_num=config.class_num)

    # gpu config
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and config.multi_gpu:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)

    # define optimizer
    assert config.optimizer in ['sgd', 'adam'], 'optim name not found!'
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=config.learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=config.learning_rate,
                                     weight_decay=config.weight_decay)

    # define learning scheduler
    assert config.scheduler in ['plateau', 'step', 'muilt_step',
                                'cosine'], 'scheduler not supported!!!'
    if config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               patience=3,
                                                               factor=0.1)
    elif config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=2,
                                                    gamma=0.9)
    elif config.scheduler == 'muilt_step':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=[30, 100],
                                                         gamma=0.1)
    elif config.scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config.epochs)

    # define loss
    criterion = torch.nn.CrossEntropyLoss()

    if use_gpu:
        criterion = criterion.cuda()
    # train val parameters dict
    state = {
        'model': net,
        'train_loader': train_loader,
        'val_loader': val_loader,
        'criterion': criterion,
        'config': config,
        'optimizer': optimizer,
        'write': write,
        'write_val': write_val
    }
    # define resume
    start_epoch = 0
    if config.resume:
        ckpt = torch.load(config.resume)
        net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']
        best_prec1 = ckpt['best_prec1']
        optimizer.load_state_dict(ckpt['optimizer'])

    # train and val
    engine = Engine()
    for e in range(start_epoch, config.epochs + 1):
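        # step-style schedulers are advanced at the start of the epoch; ReduceLROnPlateau is stepped on val_loss at the end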
        if config.scheduler in ['step', 'muilt_step']:
            scheduler.step()
        lr_train = get_lr(optimizer)
        print("Start epoch %d ==========,lr=%f" % (e, lr_train))
        train_prec, train_loss = engine.train(state, e)
        prec1, val_loss = engine.validate(state, e)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': e + 1,
                'state_dict': net.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, config.checkpoint_path)
        write.add_scalars("Accurancy", {'train': train_prec, 'val': prec1}, e)
        write.add_scalars("Loss", {'train': train_loss, 'val': val_loss}, e)
        if config.scheduler == 'plateau':
            scheduler.step(val_loss)