Example #1
import torch
# The project-local helpers used below (Network, get_loader, adjust_lr and
# trainer) are assumed to be imported from the repository's model, dataloader
# and utility modules.


def train_module(_opt):
    """Train the segmentation network with the settings carried by `_opt`.

    Expected attributes (formerly collected here via argparse): epoch
    (default 10), lr (3e-4), batchsize, trainsize (352), clip (gradient
    clipping margin, 0.5), decay_rate (0.1), decay_epoch (decay the learning
    rate every n epochs, 50), train_path, train_save and resume_snapshot.
    """
    opt = _opt

    # ---- build models ----
    torch.cuda.set_device(0)
    model = Network(channel=32, n_class=1).cuda()

    # warm-start / resume from the weights referenced by opt.resume_snapshot
    model.load_state_dict(torch.load(opt.resume_snapshot))

    params = model.parameters()
    optimizer = torch.optim.Adam(params, opt.lr)

    image_root = '{}/Imgs/'.format(opt.train_path)
    gt_root = '{}/GT/'.format(opt.train_path)
    edge_root = '{}/Edge/'.format(opt.train_path)

    train_loader = get_loader(image_root,
                              gt_root,
                              edge_root,
                              batchsize=opt.batchsize,
                              trainsize=opt.trainsize)
    total_step = len(train_loader)

    print("#" * 20, "Start Training", "#" * 20)

    for epoch in range(1, opt.epoch + 1):    # run exactly opt.epoch epochs (1-indexed)
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        trainer(train_loader=train_loader,
                model=model,
                optimizer=optimizer,
                epoch=epoch,
                opt=opt,
                total_step=total_step)
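A minimal call site for `train_module`, assuming an options namespace that carries the attributes listed in its docstring; the numeric values echo the former argparse defaults, while the batch size and both paths are placeholders rather than values from the original code.

# Illustrative caller only: paths and batch size are placeholders,
# the other numbers mirror the former argparse defaults.
from argparse import Namespace

opt = Namespace(epoch=10, lr=3e-4, batchsize=16, trainsize=352,
                clip=0.5, decay_rate=0.1, decay_epoch=50,
                train_path='./Dataset/TrainingSet',
                train_save='Inf-Net',
                resume_snapshot='./Snapshots/pretrained.pth')
train_module(opt)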
Example #2
    elif (not opt.is_pseudo) and (not opt.is_semi):
        train_save = 'Inf-Net'
    else:
        print('Use custom save path')
        train_save = opt.train_save

    # ---- calculate FLOPs and Params ----
    if opt.is_thop:
        from Code.utils.utils import CalParams
        x = torch.randn(1, 3, opt.trainsize, opt.trainsize).cuda()
        CalParams(model, x)

    # ---- load training sub-modules ----
    BCE = torch.nn.BCEWithLogitsLoss()

    params = model.parameters()
    optimizer = torch.optim.Adam(params, opt.lr)

    image_root = '{}/Imgs/'.format(opt.train_path)
    gt_root = '{}/GT/'.format(opt.train_path)
    edge_root = '{}/Edge/'.format(opt.train_path)

    train_loader = get_loader(image_root,
                              gt_root,
                              edge_root,
                              batchsize=opt.batchsize,
                              trainsize=opt.trainsize,
                              num_workers=opt.num_workers)
    total_step = len(train_loader)

    # ---- start !! ----
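The excerpt stops just before the training loop, so here is a small, self-contained illustration of the loss object it sets up: `BCEWithLogitsLoss` applies the sigmoid internally and therefore takes the network's raw prediction maps together with binary masks. The shapes and values below are arbitrary, chosen only to match the 352x352 training size used above.

# Standalone sketch of the loss configured above; tensors are dummy data.
import torch

BCE = torch.nn.BCEWithLogitsLoss()
logits = torch.randn(2, 1, 352, 352)                   # raw, un-sigmoided predictions
masks = torch.randint(0, 2, (2, 1, 352, 352)).float()  # binary ground-truth masks
print(BCE(logits, masks).item())                       # scalar loss value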
Example #3
def create_model(opt):
    # Build the Inf-Net segmentation model on the requested device and pair
    # it with an Adam optimizer using the configured learning rate.
    model = Inf_Net(channel=opt.net_channel,
                    n_class=opt.n_classes).to(opt.device)
    params = model.parameters()
    optimizer = torch.optim.Adam(params, opt.lr)
    return model, optimizer
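A hypothetical caller for `create_model`; the attribute values echo the other examples (32 channels, one class, a 3e-4 learning rate), and the `opt` namespace itself is illustrative rather than part of the original code.

# Illustrative usage only: build an opt namespace, then create the pair.
from argparse import Namespace
import torch

opt = Namespace(net_channel=32, n_classes=1, lr=3e-4,
                device='cuda' if torch.cuda.is_available() else 'cpu')
model, optimizer = create_model(opt)
model.train()   # switch to training mode before the optimization loop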