Example #1
import argparse
from datetime import datetime

import torch
from torch.utils.data import DataLoader

# Read_Data, func (the collate_fn), cuda, and WarmupMultiStepLR are assumed
# to come from the surrounding repository.


def train_dist(model, config, step, x, pre_model_file, model_file=None):
    parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    local_rank = args.local_rank
    print('******************* local_rank', local_rank)
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    assert torch.distributed.is_initialized()
    batch_size = config.gpus * config.batch_size_per_GPU
    print('--------batch_size--------', batch_size)

    model = model(config)
    print(model)
    model.eval()  # never switched to train(): Dropout/BatchNorm stay in eval mode
    model_dic = model.state_dict()

    pretrained_dict = torch.load(pre_model_file, map_location='cpu')
    # The VGG classifier weights are saved so they can be copied into the
    # detection head; the copy below is currently commented out, so these
    # four tensors go unused here.
    a = pretrained_dict['classifier.0.weight']
    b = pretrained_dict['classifier.0.bias']
    c = pretrained_dict['classifier.3.weight']
    d = pretrained_dict['classifier.3.bias']
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dic
    }
    print(len(pretrained_dict))
    model_dic.update(pretrained_dict)
    print(list(model_dic.keys()))
    # model_dic['fast.fast_head.0.weight'] = a
    # model_dic['fast.fast_head.0.bias'] = b
    # model_dic['fast.fast_head.2.weight'] = c
    # model_dic['fast.fast_head.2.bias'] = d
    model.load_state_dict(model_dic)

    if step > 0:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(model_file)
    else:
        print(pre_model_file)

    # Freeze the first 8 parameter tensors (the early backbone conv layers).
    parameters = list(model.parameters())
    for i in range(8):
        parameters[i].requires_grad = False

    model = torch.nn.parallel.DistributedDataParallel(
        model.cuda(),
        device_ids=[local_rank],
        output_device=local_rank,
        # this should be removed if we update BatchNorm stats
        broadcast_buffers=False,
    )

    train_params = list(model.parameters())[8:]

    # Split parameters so biases get no weight decay and a scaled lr.
    bias_p = []
    weight_p = []
    for name, p in model.named_parameters():
        if 'bias' in name:
            bias_p.append(p)
        else:
            weight_p.append(p)
    print(len(weight_p), len(bias_p))
    lr = config.lr * config.batch_size_per_GPU
    # When resuming past a decay milestone, start from the already-decayed lr.
    if step >= 60000 * x:
        lr = lr / 10
    if step >= 80000 * x:
        lr = lr / 10
    print('lr        ******************', lr)

    opt = torch.optim.SGD(
        [{
            'params': weight_p,
            'weight_decay': config.weight_decay,
            'lr': lr
        }, {
            'params': bias_p,
            'lr': lr * config.bias_lr_factor
        }],
        momentum=0.9,
    )

    epochs = 10000
    flag = False
    dataset = Read_Data(config)
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    dataloader = DataLoader(dataset,
                            batch_size=config.batch_size_per_GPU,
                            sampler=train_sampler,
                            collate_fn=func,
                            drop_last=True,
                            pin_memory=True)
    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)

        for imgs, bboxes, num_b, num_H, num_W in dataloader:

            loss = model(imgs, bboxes, num_b, num_H, num_W)
            loss = loss / imgs.shape[0]
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 35, norm_type=2)
            opt.step()

            # torch.cuda.empty_cache()
            if step % 20 == 0 and local_rank == 0:
                print(datetime.now(), 'loss:%.4f' % (loss),
                      opt.param_groups[0]['lr'], step)
            step += 1

            if (step == int(60000 * x) or step == int(80000 * x)):
                for param_group in opt.param_groups:
                    param_group['lr'] = param_group['lr'] / 10
                    print('***************************', param_group['lr'],
                          local_rank)
            if ((step <= 10000 and step % 1000 == 0) or step % 5000 == 0
                    or step == 1) and local_rank == 0:
                torch.save(
                    model.module.state_dict(),
                    './models/vgg16_cascade_%dx_%d_1_%d.pth' %
                    (x, step, local_rank))
            if step >= 90010 * x:
                flag = True
                break
        if flag:
            break
    if local_rank == 0:
        torch.save(
            model.module.state_dict(),
            './models/vgg16_cascade_%dx_final_1_%d.pth' % (x, local_rank))
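
A minimal launch sketch, not from the original repository: since train_dist parses --local_rank, it expects to be started once per GPU by the legacy torch.distributed.launch helper, which passes that flag and sets the env:// rendezvous variables. Model, Config, and the script name are assumptions.

# hypothetical entry point, e.g. train_script.py, launched with:
#   python -m torch.distributed.launch --nproc_per_node=2 train_script.py
if __name__ == '__main__':
    config = Config()  # assumed to define gpus, batch_size_per_GPU, lr, ...
    train_dist(Model, config,
               step=0,  # 0 = start from backbone weights; >0 = resume
               x=1,     # schedule scale: lr /10 at steps 60000*x and 80000*x
               pre_model_file='./models/vgg16.pth')  # hypothetical path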

Example #2

def train(model, config, step, x, pre_model_file, model_file=None):
    model = model(config)
    model.eval()
    model_dic = model.state_dict()

    pretrained_dict = torch.load(pre_model_file, map_location='cpu')

    # Keep the VGG classifier weights so they can be copied into the
    # fast head below.
    a = pretrained_dict['classifier.0.weight']
    b = pretrained_dict['classifier.0.bias']
    c = pretrained_dict['classifier.3.weight']
    d = pretrained_dict['classifier.3.bias']
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dic
    }
    print(len(pretrained_dict))
    model_dic.update(pretrained_dict)
    print(list(model_dic.keys()))
    model_dic['fast.fast_head.0.weight'] = a
    model_dic['fast.fast_head.0.bias'] = b
    model_dic['fast.fast_head.2.weight'] = c
    model_dic['fast.fast_head.2.bias'] = d
    model.load_state_dict(model_dic)

    if step > 0:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(model_file)
    else:
        print(pre_model_file)

    # Freeze the first 8 parameter tensors (the early backbone conv layers).
    for p in list(model.parameters())[:8]:
        p.requires_grad = False

    cuda(model)
    train_params = list(model.parameters())[8:]

    lr = config.lr * config.batch_size_per_GPU
    if step >= 60000 * x:
        lr = lr / 10
    if step >= 80000 * x:
        lr = lr / 10
    print('lr        ******************', lr)
    print('weight_decay     ******************', config.weight_decay)

    # Hard-coded toggle: the else branch (a separate group for 1-D BN weights)
    # is kept for reference but never taken.
    if True:
        bias_p = []
        weight_p = []
        print(len(train_params))
        for name, p in model.named_parameters():
            if 'bias' in name:
                bias_p.append(p)
            else:
                weight_p.append(p)
        print(len(weight_p), len(bias_p))
        opt = torch.optim.SGD(
            [{
                'params': weight_p,
                'weight_decay': config.weight_decay,
                'lr': lr
            }, {
                'params': bias_p,
                'lr': lr * config.bias_lr_factor
            }],
            momentum=0.9,
        )
    else:
        bias_p = []
        weight_p = []
        bn_weight_p = []
        print(len(train_params))
        for name, p in model.named_parameters():
            print(name, p.shape)
            if len(p.shape) == 1:
                if 'bias' in name:
                    bias_p.append(p)
                else:
                    bn_weight_p.append(p)
            else:
                weight_p.append(p)
        print(len(weight_p), len(bias_p), len(bn_weight_p))
        opt = torch.optim.SGD(
            [{
                'params': weight_p,
                'weight_decay': config.weight_decay,
                'lr': lr
            }, {
                'params': bn_weight_p,
                'lr': lr
            }, {
                'params': bias_p,
                'lr': lr * config.bias_lr_factor
            }],
            momentum=0.9,
        )
    dataset = Read_Data(config)
    dataloader = DataLoader(dataset,
                            batch_size=config.batch_size_per_GPU,
                            collate_fn=func,
                            shuffle=True,
                            drop_last=True,
                            pin_memory=True,
                            num_workers=6)

    epochs = 10000
    flag = False
    print('start:  step=', step)
    for epoch in range(epochs):
        for imgs, bboxes, num_b, num_H, num_W in dataloader:
            loss = model(imgs, bboxes, num_b, num_H, num_W)
            loss = loss / imgs.shape[0]
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 10, norm_type=2)
            opt.step()
            if step % 20 == 0:
                print(datetime.now(), 'loss:%.4f' % loss,
                      'rpn_cls_loss:%.4f' % model.a,
                      'rpn_box_loss:%.4f' % model.b,
                      'fast_cls_loss:%.4f' % model.c,
                      'fast_box_loss:%.4f' % model.d, model.fast_num,
                      model.fast_num_P, opt.param_groups[0]['lr'], step)
            step += 1

            if step == int(60000 * x) or step == int(80000 * x):
                for param_group in opt.param_groups:
                    param_group['lr'] = param_group['lr'] / 10
                    print('*******************************************',
                          param_group['lr'])

            if (step <= 10000
                    and step % 1000 == 0) or step % 5000 == 0 or step == 1:
                torch.save(model.state_dict(),
                           './models/vgg16_cascade_%d_2.pth' % step)
            if step >= 90010 * x:
                flag = True
                break
        if flag:
            break
    torch.save(model.state_dict(), './models/vgg16_cascade_final_2.pth')
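
A resume sketch for the single-process trainer, under the same assumed Model/Config names: step both pre-decays the learning rate (step >= 60000 * x) and continues the checkpoint numbering, and model_file follows the vgg16_cascade_%d_2.pth pattern the loop saves.

# hypothetical resume from the 65000-step checkpoint of an x=1 run:
train(Model, Config(),
      step=65000, x=1,
      pre_model_file='./models/vgg16.pth',              # hypothetical path
      model_file='./models/vgg16_cascade_65000_2.pth')  # saved by the loop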
Example #3
def train(model, config, step, x, pre_model_file, model_file=None):
    dataset = Read_Data(config)
    dataloader = DataLoader(dataset, batch_size=config.batch_size_per_GPU, collate_fn=func,
                            shuffle=True, drop_last=True, pin_memory=True)
    model = model(config)
    print(model)
    model.eval()
    model_dic = model.state_dict()
    pretrained_dict = torch.load(pre_model_file, map_location='cpu')
    # The backbone checkpoint is keyed without the 'features.' prefix the
    # detector uses, so remap the keys before filtering.
    pretrained_dict = {'features.' + k: v for k, v in pretrained_dict.items() if 'features.' + k in model_dic}
    print('*******', len(pretrained_dict))
    model_dic.update(pretrained_dict)
    model.load_state_dict(model_dic)
    if step > 0:
        model.load_state_dict(torch.load(model_file, map_location='cpu'))
        print(model_file)
    else:
        print(pre_model_file)
    cuda(model)

    train_params = list(model.parameters())
    bias_p = []
    weight_p = []

    for name, p in model.named_parameters():
        if 'bias' in name:
            bias_p.append(p)
        else:
            weight_p.append(p)
    print(len(weight_p), len(bias_p))
    lr = config.lr * config.batch_size_per_GPU
    # When resuming past a decay milestone, start from the already-decayed lr.
    if step >= 60000 * x:
        lr = lr / 10
    if step >= 80000 * x:
        lr = lr / 10
    print('lr        ******************', lr)
    opt = torch.optim.SGD(
        [{'params': weight_p, 'weight_decay': config.weight_decay, 'lr': lr},
         {'params': bias_p, 'lr': lr * config.bias_lr_factor}],
        momentum=0.9, )
    scheduler = WarmupMultiStepLR(opt, [60000 * x, 80000 * x], warmup_factor=1 / 3, warmup_iters=500)
    epochs = 10000
    flag = False
    print('start:  step=', step)
    for epoch in range(epochs):

        for imgs, bboxes, num_b, num_H, num_W in dataloader:

            loss = model(imgs, bboxes, num_b, num_H, num_W)
            loss = loss / imgs.shape[0]
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 10, norm_type=2)
            opt.step()
            scheduler.step()
            if step % 20 == 0:
                print(datetime.now(), 'loss:%.4f' % loss, 'rpn_cls_loss:%.4f' % model.a,
                      'rpn_box_loss:%.4f' % model.b,
                      'fast_cls_loss:%.4f' % model.c, 'fast_box_loss:%.4f' % model.d,
                      model.fast_num,
                      model.fast_num_P, opt.param_groups[0]['lr'], step)
            step += 1

            # if step == int(60000 * x) or step == int(80000 * x):
            #     for param_group in opt.param_groups:
            #         param_group['lr'] = param_group['lr'] / 10
            #         print('*********************************', param_group['lr'])

            if (step <= 10000 and step % 1000 == 0) or step % 5000 == 0 or step == 1:
                torch.save(model.state_dict(), './models/FPN_50_%d_1.pth' % step)

            if step >= 90010 * x:
                flag = True
                break
        if flag:
            break
    torch.save(model.state_dict(), './models/FPN_50_final_1.pth')
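
WarmupMultiStepLR comes from the surrounding repository (it matches the maskrcnn-benchmark scheduler of the same name: linear warmup from warmup_factor to 1 over warmup_iters, then lr /10 at each milestone). If it is unavailable, a rough stand-in can be sketched with the stock LambdaLR:

import bisect
from torch.optim.lr_scheduler import LambdaLR

def warmup_multistep(opt, milestones, warmup_factor=1 / 3,
                     warmup_iters=500, gamma=0.1):
    # Linear warmup over the first warmup_iters steps, then multiply the
    # base lr by gamma once for every milestone already passed.
    def factor(it):
        warmup = 1.0
        if it < warmup_iters:
            alpha = it / warmup_iters
            warmup = warmup_factor * (1 - alpha) + alpha
        return warmup * gamma ** bisect.bisect_right(milestones, it)
    return LambdaLR(opt, lr_lambda=factor)

# scheduler = warmup_multistep(opt, [60000 * x, 80000 * x])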