Example #1
0
def main():
    """Train the Deep-RTC model (feature extractor + hierarchical classifier)
    on CIFAR, with LR warmup, step decay, checkpoint resume and periodic
    checkpoint saving.

    Reads everything from the module-level ``cfg`` dict and uses the
    module-level ``logger`` and ``writer``; publishes a few values as globals
    for the module-level ``train()`` to read.
    """
    # Shared with train(): loss trade-off, steps per epoch, and tree metadata.
    global lmbda, n_step, node_labels, n_nodes

    # setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))

    if not torch.cuda.is_available():
        raise SystemExit('GPU is needed')

    # Echo which GPUs the job actually sees (debug aid in cluster logs).
    os.system('echo $CUDA_VISIBLE_DEVICES')

    # setup data loader
    data_loader = get_hierdataset_cifar(cfg['data'],
                                        cfg['training']['batch_size'],
                                        cfg['training']['n_workers'],
                                        ['train'])
    # Batches per epoch (floor division: a trailing partial batch is not
    # counted).
    n_step = int(
        len(data_loader['train'].dataset) //
        float(cfg['training']['batch_size']))
    node_labels = data_loader['node_labels']
    nodes = data_loader['nodes']
    n_nodes = len(nodes)

    # setup Deep-RTC model (feature extractor + classifier), each wrapped in
    # DataParallel over all visible GPUs
    n_gpu = torch.cuda.device_count()
    model_fe = get_model(cfg['model']['fe']).cuda()
    model_fe = nn.DataParallel(model_fe, device_ids=range(n_gpu))

    model_cls = get_model(cfg['model']['cls'], nodes).cuda()
    model_cls = nn.DataParallel(model_cls, device_ids=range(n_gpu))

    model_pivot = get_model(cfg['model']['pivot']).cuda()
    model_pivot = nn.DataParallel(model_pivot, device_ids=range(n_gpu))

    # loss function; reduction='none' keeps per-sample losses so train() can
    # combine them itself (presumably weighted by lmbda — confirm in train())
    criterion = nn.CrossEntropyLoss(reduction='none')
    lmbda = cfg['training']['lmbda']

    # setup optimizer: add_weight_decay builds per-parameter groups, so the
    # decay value is removed from the optimizer kwargs to avoid double decay
    opt_main_cls, opt_main_params = get_optimizer(
        cfg['training']['optimizer_main'])
    cnn_params = list(model_fe.named_parameters()) + list(
        model_cls.named_parameters())
    cnn_params = add_weight_decay(cnn_params, opt_main_params['weight_decay'])
    opt_main_params.pop('weight_decay', None)
    opt_main = opt_main_cls(cnn_params, **opt_main_params)
    logger.info('Using optimizer {}'.format(opt_main))

    cudnn.benchmark = True

    # load checkpoint: restores fe/cls weights; epoch and optimizer state are
    # restored only when param_only is False
    start_ep = 0
    if cfg['training']['resume'].get('model', None):
        resume = cfg['training']['resume']
        if os.path.isfile(resume['model']):
            logger.info("Loading model from checkpoint '{}'".format(
                resume['model']))
            checkpoint = torch.load(resume['model'])
            # cvt2normal_state presumably strips the DataParallel 'module.'
            # prefix so the state dicts load into .module — verify helper.
            model_fe.module.load_state_dict(
                cvt2normal_state(checkpoint['model_fe_state']))
            model_cls.module.load_state_dict(
                cvt2normal_state(checkpoint['model_cls_state']))
            if resume['param_only'] is False:
                start_ep = checkpoint['epoch']
                opt_main.load_state_dict(checkpoint['opt_main_state'])
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                resume['model'], checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(resume['model']))

    print('Start training from epoch {}'.format(start_ep))
    logger.info('Start training from epoch {}'.format(start_ep))

    for ep in range(start_ep, cfg['training']['epoch']):

        # Warmup: for the first 5 epochs the LR is base_lr * (ep + 1).
        if (ep + 1) <= 5:
            assign_learning_rate(opt_main, lr=opt_main_params['lr'] * (ep + 1))

        # NOTE(review): this fires at EVERY multiple of 160 or 180 epochs,
        # not only at epochs 160 and 180 — harmless for runs <= 200 epochs,
        # but confirm intent for longer schedules.
        if (ep + 1) % 160 == 0 or (ep + 1) % 180 == 0:
            adjust_learning_rate(opt_main, decay_rate=0.1)

        train(data_loader['train'], model_fe, model_cls, model_pivot, opt_main,
              ep, criterion)

        # Periodic checkpoint: save the new one, then delete the previous
        # interval's file so only the latest checkpoint is kept on disk.
        if (ep + 1) % cfg['training']['save_interval'] == 0:
            state = {
                'epoch': ep + 1,
                'model_fe_state': model_fe.state_dict(),
                'model_cls_state': model_cls.state_dict(),
                'opt_main_state': opt_main.state_dict()
            }
            ckpt_path = os.path.join(writer.file_writer.get_logdir(),
                                     "ep-{ep}_model.pkl")
            save_path = ckpt_path.format(ep=ep + 1)
            last_path = ckpt_path.format(ep=ep + 1 -
                                         cfg['training']['save_interval'])
            torch.save(state, save_path)
            if os.path.isfile(last_path):
                os.remove(last_path)
            print_str = '[Checkpoint]: {} saved'.format(save_path)
            print(print_str)
            logger.info(print_str)
Example #2
0
# Number of complete batches per epoch (int() truncates, so a final partial
# batch is not counted here).
num_train_batches = int(len(train_dataset) / batch_size)
num_val_batches = int(len(val_dataset) / batch_size)
print("num_train_batches:", num_train_batches)
print("num_val_batches:", num_val_batches)

# Shuffle only the training split; one worker process each.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=1)

# Parameter groups with selective L2 decay (see add_weight_decay); Adam then
# applies its kwargs on top of each group.
params = add_weight_decay(network, l2_value=0.0001)
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Pre-computed per-class weights for the loss, pickled by the data-prep step.
with open("/root/deeplabv3/data/cityscapes/meta/class_weights.pkl",
          "rb") as file:  # (needed for python3)
    class_weights = np.array(pickle.load(file))
class_weights = torch.from_numpy(class_weights)
# NOTE(review): Variable is a no-op wrapper since PyTorch 0.4 — kept as-is.
class_weights = Variable(class_weights.type(torch.FloatTensor)).cuda()

# loss function
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# Per-epoch loss histories, filled by the training loop below.
epoch_losses_train = []
epoch_losses_val = []
for epoch in range(num_epochs):
    print("###########################")
Example #3
0
def main(opt):
    """Train a weather classifier according to the options in ``opt``.

    Sets up seeding and the device, builds a DataParallel model over 8 GPUs,
    selects the loss from ``opt.criterion``, freezes all but the last layers
    when fine-tuning (``opt.retrain`` without ``opt.teacher_mode``), then
    trains over one or more curriculum stages, checkpointing after each.

    Relies on module-level helpers: ``setup_seed``, ``get_model``,
    ``get_optimizer``, ``get_schedule``, ``my_transform``, ``train_model``,
    ``WeatherDataset`` and ``utils``.
    """
    setup_seed(opt.seed)
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.set_device(opt.gpu_id)
    else:
        device = torch.device('cpu')

    log_dir = opt.log_dir + '/' + opt.network + '-' + str(opt.layers)
    utils.mkdir(log_dir)

    model = get_model(opt)
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
    model = model.to(device)

    summary_writer = SummaryWriter(logdir=log_dir)

    # Hand-tuned per-class loss weights for the class counts used so far;
    # weight stays None (unweighted) for any other opt.classes value.
    weight = None
    if opt.classes == 9:
        weight = torch.tensor([1.8, 1, 1, 1.2, 1, 1.6, 1.2, 1.4, 1],
                              device=device)
    elif opt.classes == 8:
        weight = torch.tensor([1.8, 1, 1.2, 1.6, 1, 1.2, 1.8, 1],
                              device=device)
    elif opt.classes == 2:
        weight = torch.tensor([1., 1.5], device=device)

    # NOTE(review): the .cuda() calls below assume CUDA is available even
    # though the device selection above allows CPU — confirm deployment.
    if opt.criterion == 'lsr':
        criterion = LabelSmoothSoftmaxCE(weight=weight,
                                         use_focal_loss=opt.use_focal,
                                         reduction='sum').cuda()
    elif opt.criterion == 'focal':
        criterion = FocalLoss2()
    elif opt.criterion == 'ce':
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='sum').cuda()
    elif opt.criterion == 'bce':
        criterion = nn.BCEWithLogitsLoss(weight=weight, reduction='sum').cuda()

    # Load images/labels. opt.classes < 2 is unsupported and would leave
    # `images`/`labels` unbound below.
    if opt.classes > 2:
        # all data
        images, labels = utils.read_data(
            os.path.join(opt.root_dir, opt.train_dir),
            os.path.join(opt.root_dir, opt.train_label), opt.train_less,
            opt.clean_data)
    elif opt.classes == 2:
        # 2 categories
        images, labels = utils.read_ice_snow_data(
            os.path.join(opt.root_dir, opt.train_dir),
            os.path.join(opt.root_dir, opt.train_label))

    ################ divide set #################
    # Hold out opt.num_val samples for validation, from the front (opt.fore)
    # or from the back of the dataset.
    if opt.fore:
        train_im, train_label = images[opt.num_val:], labels[opt.num_val:]
        val_im, val_label = images[:opt.num_val], labels[:opt.num_val]
    else:
        train_im, train_label = images[:-opt.num_val], labels[:-opt.num_val]
        val_im, val_label = images[-opt.num_val:], labels[-opt.num_val:]

    if opt.cu_mode:
        # Curriculum mode: train on growing subsets, splitting the total
        # epoch budget evenly across the stages.
        train_data_2 = train_im[:5385], train_label[:5385]
        train_data_3 = train_im, train_label
        train_datas = [train_data_2, train_data_3]
        opt.num_epochs //= len(train_datas)
    else:
        train_datas = [(train_im, train_label)]
    val_data = val_im, val_label
    #########################################

    if opt.retrain:
        state_dict = torch.load(opt.model_dir + '/' + opt.network + '-' +
                                str(opt.layers) + '-' + str(opt.crop_size) +
                                '_model.ckpt')
        model.load_state_dict(state_dict)

    ################ parameter selection #################
    if opt.retrain and not opt.teacher_mode:
        # Fine-tuning: freeze everything, then re-enable the classifier head
        # plus the last normalization/feature block of each architecture.
        if opt.network in ['effnet']:
            for param in model.module.parameters():
                param.requires_grad = False
            for param in model.module._fc.parameters():
                param.requires_grad = True
            for param in model.module.model._bn1.parameters():
                param.requires_grad = True

        elif opt.network in ['resnet', 'resnext',
                             'resnext_wsl_32x8d', 'resnext_wsl_32x16d',
                             'resnext_wsl_32x32d', 'resnext_swsl']:
            for param in model.module.parameters():
                param.requires_grad = False
            for param in model.module.fc.parameters():
                param.requires_grad = True
            for param in model.module.layer4[2].bn3.parameters():
                param.requires_grad = True

        elif opt.network in ['pnasnet_m', 'senet_m']:
            for param in model.module.parameters():
                param.requires_grad = False
            for param in model.module.classifier.parameters():
                param.requires_grad = True
            if opt.network == 'senet_m':
                for param in model.module.features.layer4.parameters():
                    param.requires_grad = True

        elif opt.network in ['inception_v3']:
            # BUGFIX: the model is DataParallel-wrapped, so its submodules
            # must be reached through model.module; the original plain
            # `model.fc` / `model.Mixed_7c` raised AttributeError here.
            for param in model.module.parameters():
                param.requires_grad = False
            for param in model.module.fc.parameters():
                param.requires_grad = True
            for param in model.module.Mixed_7c.parameters():
                param.requires_grad = True
        else:
            for param in model.module.parameters():
                param.requires_grad = False
            for param in model.module.last_linear.parameters():
                param.requires_grad = True
            if opt.network in ['se_resnext50_32x4d', 'se_resnext101_32x4d']:
                for param in model.module.layer4[2].bn3.parameters():
                    param.requires_grad = True
            elif opt.network in ['senet154']:
                for param in model.module.layer4.parameters():
                    param.requires_grad = True
            elif opt.network in ['xception']:
                for param in model.module.bn4.parameters():
                    param.requires_grad = True
            elif opt.network in ['inceptionresnetv2']:
                for param in model.module.conv2d_7b.bn.parameters():
                    param.requires_grad = True
            elif opt.network in ['inceptionv4']:
                for param in model.module.features[-1].branch3.parameters():
                    param.requires_grad = True
            elif opt.network in ['fixpnas']:
                for param in model.module.cell_11.parameters():
                    param.requires_grad = True
                for param in model.module.cell_10.parameters():
                    param.requires_grad = True
                for param in model.module.cell_9.parameters():
                    param.requires_grad = True
        # Only the still-trainable parameters go to the optimizer.
        params = filter(lambda p: p.requires_grad, model.parameters())
    else:
        if opt.network in ['effnet'] and not opt.retrain:
            # Backbone with selective weight decay, plus a 10x learning rate
            # on the freshly initialised classifier head.
            params = utils.add_weight_decay(model.module.model, 1e-4)
            params.append({
                'params': model.module._fc.parameters(),
                'lr': opt.lr * 10
            })
        else:
            params = utils.add_weight_decay(model, 1e-4)

    ################ optimizer #################
    optimizer = get_optimizer(opt, params, weight_decay=1e-4)
    # Epoch-driven schedulers are built once here; batch-driven ones need
    # len(train_loader) and are built per stage below.  BUGFIX: initialise to
    # None so an unrecognised opt.scheduler no longer raises NameError when
    # `scheduler` is passed to train_model.
    scheduler = None
    if opt.scheduler in ['step', 'multistep', 'plateau', 'exponential']:
        scheduler = get_schedule(opt, optimizer)
    ############################################

    crop_size = opt.crop_size - 128
    val_transforms = my_transform(False, crop_size)
    val_dataset = WeatherDataset(val_data[0], val_data[1], val_transforms)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True)
    for train_data in train_datas:
        # Class distributions; the 1e-20 keeps later divisions by the
        # validation class counts finite (presumably for per-class accuracy
        # in train_model — verify).
        val_dis = np.bincount(val_label) + 1e-20
        train_dis = np.bincount(train_data[1])
        print(val_dis, opt.num_val)
        print(train_dis, len(train_data[1]))

        train_transforms = my_transform(True, crop_size, opt.cutout,
                                        opt.n_holes, opt.length, opt.auto_aug,
                                        opt.rand_aug)
        train_dataset = WeatherDataset(train_data[0], train_data[1],
                                       train_transforms)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   drop_last=True,
                                                   pin_memory=True)

        loader = {'train': train_loader, 'val': val_loader}

        ################ scheduler #################
        # Batch-driven schedulers depend on the stage's loader length.
        if opt.scheduler in ['warmup', 'cycle', 'cos', 'cosw', 'sgdr']:
            scheduler = get_schedule(opt, optimizer, len(train_loader))
        ############################################

        model, acc = train_model(loader,
                                 model,
                                 criterion,
                                 optimizer,
                                 summary_writer,
                                 scheduler=scheduler,
                                 scheduler_name=opt.scheduler,
                                 num_epochs=opt.num_epochs,
                                 device=device,
                                 is_inception=opt.is_inception,
                                 mixup=opt.mixup,
                                 cutmix=opt.cutmix,
                                 alpha=opt.alpha,
                                 val_dis=val_dis)
        utils.mkdir(opt.model_dir)
        torch.save(
            model.state_dict(), opt.model_dir + '/' + opt.network + '-' +
            str(opt.layers) + '-' + str(crop_size) + '_model.ckpt')
Example #4
0
def main(opt):
    """Train a weather classifier on CPU according to ``opt``.

    CPU-only variant of the training entry point: the model is left
    unwrapped (no DataParallel) and the device is fixed to ``cpu``.  Builds
    the loss and optimizer, splits off the last/first ``opt.num_val`` samples
    for validation, then trains once per crop size and saves a checkpoint.

    Relies on module-level helpers: ``setup_seed``, ``get_model``,
    ``get_optimizer``, ``get_schedule``, ``my_transform``, ``train_model``,
    ``WeatherDataset`` and ``utils``.
    """
    setup_seed(opt.seed)

    device = torch.device('cpu')

    model = get_model(opt)
    model = model.to(device)

    summary_writer = SummaryWriter(logdir=opt.log_dir)
    # Hand-tuned per-class loss weights for the 9-class setup.
    weight = torch.tensor([1.8, 1, 1, 1.2, 1, 1.6, 1, 1.2, 1], device=device)

    # BUGFIX: the criteria were built with .cuda(), which contradicts the
    # hard-coded CPU device above and crashes on CUDA-less machines; keep
    # them on `device` instead.
    if opt.criterion == 'lsr':
        criterion = LabelSmoothSoftmaxCE(weight=weight,
                                         use_focal_loss=opt.use_focal,
                                         reduction='sum').to(device)
    elif opt.criterion == 'focal':
        criterion = FocalLoss(alpha=1, gamma=2, reduction='sum')
    elif opt.criterion == 'ce':
        criterion = nn.CrossEntropyLoss(weight=weight,
                                        reduction='sum').to(device)
    elif opt.criterion == 'bce':
        criterion = nn.BCEWithLogitsLoss(weight=weight,
                                         reduction='sum').to(device)

    # all data
    images, labels = utils.read_data(
        os.path.join(opt.root_dir, opt.train_dir),
        os.path.join(opt.root_dir, opt.train_label),
        opt.train_less, opt.clean_data)

    ################ divide set #################
    # Hold out opt.num_val samples for validation, from the front (opt.fore)
    # or from the back of the dataset.
    if opt.fore:
        train_im, train_label = images[opt.num_val:], labels[opt.num_val:]
        val_im, val_label = images[:opt.num_val], labels[:opt.num_val]
    else:
        train_im, train_label = images[:-opt.num_val], labels[:-opt.num_val]
        val_im, val_label = images[-opt.num_val:], labels[-opt.num_val:]
    # Class distributions; 1e-20 keeps later divisions by the validation
    # class counts finite.
    val_dis = np.bincount(val_label) + 1e-20
    train_dis = np.bincount(train_label)

    print(val_dis, opt.num_val)
    print(train_dis, len(train_label))

    train_data = train_im, train_label
    val_data = val_im, val_label
    #########################################

    if opt.retrain:
        state_dict = torch.load(
            opt.model_dir + '/' + opt.network + '-' + str(opt.layers) + '-' +
            str(opt.crop_size) + '_model.ckpt')
        model.load_state_dict(state_dict)

    ################ optimizer #################
    # BUGFIX: the model is never wrapped in DataParallel in this variant
    # (the wrapping is commented out above), so it has no .module attribute;
    # pass the model itself to add_weight_decay.
    params = utils.add_weight_decay(model, 1e-4)
    optimizer = get_optimizer(opt, params)

    for crop_size in [opt.crop_size]:

        train_transforms = my_transform(True, crop_size, opt.cutout,
                                        opt.n_holes, opt.length, opt.auto_aug,
                                        opt.raugment, opt.N, opt.M)
        train_dataset = WeatherDataset(train_data[0], train_data[1],
                                       train_transforms)
        val_transforms = my_transform(False, crop_size)
        val_dataset = WeatherDataset(val_data[0], val_data[1], val_transforms)

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=opt.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   drop_last=True)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=opt.batch_size,
                                                 shuffle=False,
                                                 num_workers=4)
        loader = {'train': train_loader, 'val': val_loader}

        ################ scheduler #################
        scheduler = get_schedule(opt, optimizer, len(train_loader))
        ############################################

        model, acc = train_model(loader, model, criterion, optimizer,
                                 summary_writer, scheduler=scheduler,
                                 scheduler_name=opt.scheduler,
                                 num_epochs=opt.num_epochs, device=device,
                                 is_inception=opt.is_inception,
                                 mixup=opt.mixup, alpha=opt.alpha,
                                 val_dis=val_dis)
        utils.mkdir(opt.model_dir)
        utils.mkdir(opt.log_dir + '/' + opt.network + '-' + str(opt.layers))
        torch.save(
            model.state_dict(), opt.model_dir + '/' + opt.network + '-' +
            str(opt.layers) + '-' + str(crop_size) + '_model.ckpt')
Example #5
0
    checkpoint = torch.load(args.init_from)
    loaded_params = {}
    for k, v in checkpoint['net'].items():
        if not k.startswith("module."):
            loaded_params["module." + k] = v
        else:
            loaded_params[k] = v

    net_state_dict = net.state_dict()
    net_state_dict.update(loaded_params)
    net.load_state_dict(net_state_dict)
else:
    warnings.warn("No checkpoint file is provided !!!")

# Plain cross-entropy objective (no class weighting here).
criterion = nn.CrossEntropyLoss()

# Build optimizer parameter groups; parameters matching skip_keys are
# presumably exempted from weight decay — see utils.add_weight_decay.
params = utils.add_weight_decay(
    net, weight_decay=args.wd, skip_keys=['delta', 'alpha'])

# Setup optimizer
# ----------------------------
if args.optimizer == 'sgd':
    print("==> Use SGD optimizer")
    optimizer = optim.SGD(
        params, lr=args.lr, momentum=0.9, weight_decay=args.wd)
elif args.optimizer == 'adam':
    print("==> Use Adam optimizer")
    optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.wd)