Code example #1
0
def train(config):
    """Train the flow-estimation network against a frozen copy of itself.

    Two copies of ``Network`` are built — one trainable, one frozen
    pseudo-ground-truth — and both are loaded from the same default
    checkpoint.  The trainable copy is then optimised with one of three
    loss modes selected by ``config.loss``:

    * ``'image'``     — masked warping loss on all four blur/sharp frame
                        pairings plus a mask-shape loss against the GT mask.
    * ``'image_ss'``  — warping + mask-shape loss on the sharp/sharp pair only.
    * ``'flow_only'`` — mean MSE between each predicted flow and the GT
                        sharp/sharp flow.

    Scalars are written to TensorBoard, sample images to
    ``config.LOG_DIR.sample``, and checkpoints through ``CKPT_Manager``.

    Args:
        config: experiment configuration object (LOG_DIR paths, image size,
            optimiser hyper-parameters, logging intervals, loss mode).

    Raises:
        ValueError: if ``config.loss`` is not one of the three modes above.
            (The original code silently reused a stale ``errs['total']`` in
            that case, or crashed with a KeyError on the first iteration.)
    """
    summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

    ## inputs: frame pairs — b_* are blurry frames, s_* sharp ones; *_t_1 is
    ## the previous frame.  Values are filled in by the data loader each step.
    inputs = {'b_t_1': None, 'b_t': None, 's_t_1': None, 's_t': None}
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))

    ## model: trainable network + frozen copy used as pseudo ground truth
    print(toGreen('Loading Model...'))
    moduleNetwork = Network().to(device)
    moduleNetwork.apply(weights_init)
    moduleNetwork_gt = Network().to(device)
    print(moduleNetwork)

    ## checkpoint manager — both nets start from the same default weights
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                config.max_ckpt_num)
    moduleNetwork.load_state_dict(
        torch.load('./network/network-default.pytorch'))
    moduleNetwork_gt.load_state_dict(
        torch.load('./network/network-default.pytorch'))

    ## data loader
    print(toGreen('Loading Data Loader...'))
    data_loader = Data_Loader(config,
                              is_train=True,
                              name='train',
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config,
                                   is_train=False,
                                   name="test",
                                   thread_num=config.thread_num)

    data_loader.init_data_loader(inputs)
    data_loader_test.init_data_loader(inputs)

    ## loss, optim
    print(toGreen('Building Loss & Optim...'))
    MSE_sum = torch.nn.MSELoss(reduction='sum')
    MSE_mean = torch.nn.MSELoss()
    optimizer = optim.Adam(moduleNetwork.parameters(),
                           lr=config.lr_init,
                           betas=(config.beta1, 0.999))
    errs = collections.OrderedDict()

    def _flow(module, img_t, img_t_1):
        # Predict flow from img_t to img_t_1 and upsample to full resolution.
        return torch.nn.functional.interpolate(
            input=module(img_t, img_t_1),
            size=(config.height, config.width),
            mode='bilinear',
            align_corners=False)

    def _warp_mask(flow):
        # Warp an all-ones image to get a validity mask for `flow`
        # (near-zero where the flow samples from outside the frame).
        return warp(tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow)

    def _masked_mse(warped, mask):
        # Sum-MSE over valid pixels only, normalised by the mask area.
        return MSE_sum(warped * mask, inputs['s_t']) / mask.sum()

    print(toYellow('======== TRAINING START ========='))
    max_epoch = 10000
    itr = 0
    for epoch in np.arange(max_epoch):

        # train: iterate until the loader signals the end of the epoch
        while True:
            itr_time = time.time()

            inputs, is_end = data_loader.get_feed()
            if is_end: break

            optimizer.zero_grad()

            if config.loss == 'image':
                # All four pairings of (blurry|sharp) current frame with
                # (blurry|sharp) previous frame.
                flows = {
                    'bb': _flow(moduleNetwork, inputs['b_t'], inputs['b_t_1']),
                    'bs': _flow(moduleNetwork, inputs['b_t'], inputs['s_t_1']),
                    'sb': _flow(moduleNetwork, inputs['s_t'], inputs['b_t_1']),
                    'ss': _flow(moduleNetwork, inputs['s_t'], inputs['s_t_1']),
                }
                with torch.no_grad():
                    flow_ss_gt = _flow(moduleNetwork_gt,
                                       inputs['s_t'], inputs['s_t_1'])
                    s_t_warped_ss_mask_gt = _warp_mask(flow_ss_gt)

                warped, masks = {}, {}
                for key, flow in flows.items():
                    warped[key] = warp(tensorInput=inputs['s_t_1'],
                                       tensorFlow=flow)
                    masks[key] = _warp_mask(flow)
                for key in flows:
                    errs['MSE_' + key] = _masked_mse(warped[key], masks[key])
                for key in flows:
                    errs['MSE_' + key + '_mask_shape'] = MSE_mean(
                        masks[key], s_t_warped_ss_mask_gt)
                errs['total'] = (
                    sum(errs['MSE_' + k] for k in flows)
                    + sum(errs['MSE_' + k + '_mask_shape'] for k in flows))

                # kept for the image-logging section below
                s_t_warped_ss = warped['ss']
                s_t_warped_bb_mask = masks['bb']

            elif config.loss == 'image_ss':
                # Sharp/sharp pair only.
                flow_ss = _flow(moduleNetwork, inputs['s_t'], inputs['s_t_1'])
                with torch.no_grad():
                    flow_ss_gt = _flow(moduleNetwork_gt,
                                       inputs['s_t'], inputs['s_t_1'])
                    s_t_warped_ss_mask_gt = _warp_mask(flow_ss_gt)

                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                s_t_warped_ss_mask = _warp_mask(flow_ss)

                errs['MSE_ss'] = _masked_mse(s_t_warped_ss,
                                             s_t_warped_ss_mask)
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_ss'] + errs['MSE_ss_mask_shape']

            elif config.loss == 'flow_only':
                # Supervise every predicted flow directly against the GT
                # sharp/sharp flow (liteflow_flow_only).
                flows = {
                    'bb': _flow(moduleNetwork, inputs['b_t'], inputs['b_t_1']),
                    'bs': _flow(moduleNetwork, inputs['b_t'], inputs['s_t_1']),
                    'sb': _flow(moduleNetwork, inputs['s_t'], inputs['b_t_1']),
                    'ss': _flow(moduleNetwork, inputs['s_t'], inputs['s_t_1']),
                }
                # Warped image kept only for the logging section below.
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flows['ss'])
                with torch.no_grad():
                    flow_ss_gt = _flow(moduleNetwork_gt,
                                       inputs['s_t'], inputs['s_t_1'])

                for key in flows:
                    errs['MSE_' + key + '_ss'] = MSE_mean(flows[key],
                                                          flow_ss_gt)
                errs['total'] = sum(errs['MSE_' + k + '_ss'] for k in flows)

            else:
                # Fail fast instead of silently reusing a stale total loss.
                raise ValueError(
                    'unknown config.loss: {}'.format(config.loss))

            errs['total'].backward()
            optimizer.step()

            # Step-wise learning-rate decay.
            lr = adjust_learning_rate(optimizer, epoch, config.decay_rate,
                                      config.decay_every, config.lr_init)

            if itr % config.write_log_every_itr == 0:
                summary.add_scalar('loss/loss_mse', errs['total'].item(), itr)
                vutils.save_image(inputs['s_t_1'].detach().cpu(),
                                  '{}/{}_1_input.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)
                vutils.save_image(s_t_warped_ss.detach().cpu(),
                                  '{}/{}_2_warped_ss.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)
                vutils.save_image(inputs['s_t'].detach().cpu(),
                                  '{}/{}_3_gt.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3,
                                  padding=0,
                                  normalize=False)

                if config.loss == 'image_ss':
                    vutils.save_image(s_t_warped_ss_mask.detach().cpu(),
                                      '{}/{}_4_s_t_wapred_ss_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                elif config.loss != 'flow_only':
                    vutils.save_image(s_t_warped_bb_mask.detach().cpu(),
                                      '{}/{}_4_s_t_wapred_bb_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)

            # Periodically clear accumulated sample images.
            if itr % config.refresh_image_log_every_itr == 0:
                remove_file_end_with(config.LOG_DIR.sample, '*.png')

            print_logs('TRAIN',
                       config.mode,
                       epoch,
                       itr_time,
                       itr,
                       data_loader.num_itr,
                       errs=errs,
                       lr=lr)
            itr += 1

        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(moduleNetwork,
                                   epoch,
                                   score=errs['total'].item())
Code example #2
0
File: train.py  Project: lyx997/Majiang

def weights_init(m):  ##定义参数初始化函数
    classname = m.__class__.__name__  # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。
    if classname.find(
            'Conv') != -1:  #find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.
        torch.nn.init.normal_(
            m.weight.data, 0.0, 0.02
        )  #m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data,
                                0)  # nn.init.constant_()表示将偏差定义为常量0


# Initialise every layer of the network with the DCGAN-style scheme above,
# then switch the whole model to float64 precision.
net.apply(weights_init)
net = net.double()
print(2)  # progress marker: model ready

# Shuffled mini-batch loader over the dataset (2 worker processes).
dataLoader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         num_workers=2,
                                         shuffle=True)
# NOTE(review): the name `optim` shadows the common `torch.optim` alias if
# that module was imported as `optim` earlier — verify against the imports.
# lr=0.1 is unusually high for Adam; presumably intentional — TODO confirm.
optim = torch.optim.Adam(net.parameters(), lr=0.1, betas=(0.5, 0.999))

# Fixed samples from the dataset, converted to tensors; presumably used to
# monitor progress across epochs — verify against the loop body below.
fixed_x, fixed_label = dataset.get_fixed()
fixed_x = torch.tensor(fixed_x)
fixed_label = torch.tensor(fixed_label)
print(3)  # progress marker: data ready

for epoch in range(1):
Code example #3
0
    # Create the output folder on first run.
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
        print("-- making output folder: " + output_folder)

    # Training set: one sub-folder per class, images loaded as float tensors.
    # NOTE(review): `train_foler` / `dev_foler` look like typos of "folder";
    # they must match the variable names defined earlier in this file.
    train_dataset = torchvision.datasets.ImageFolder(
        root=train_foler, transform=torchvision.transforms.ToTensor())
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=10, shuffle=True, num_workers=8)

    # Held-out dev set with the same folder layout as the training set.
    dev_dataset = torchvision.datasets.ImageFolder(
        root=dev_foler, transform=torchvision.transforms.ToTensor())
    dev_dataloader = torch.utils.data.DataLoader(
        dev_dataset, batch_size=10, shuffle=True, num_workers=8)
    # Class count is inferred from the training folder structure.
    num_classes = len(train_dataset.classes)
    network = Network(num_feats, hidden_sizes, num_classes)
    network.apply(init_weights)

    # Standard classification loss; plain SGD with momentum + weight decay.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        network.parameters(),
        lr=lr,
        weight_decay=weightDecay,
        momentum=0.9)
    print("-- ready to rock to network.train")
    network.train()  # enable training mode (dropout / batch-norm behaviour)
    print("-- finish network.train")
    network.to(device)
    print("-- ready to rock to train()")
    train(network, train_dataloader, dev_dataloader, numEpochs)
    print("-- finish train()")
    # NOTE(review): torch.save on a module pickles the whole object; the
    # ".npy" extension is misleading (this is not a NumPy file) — confirm
    # downstream code loads it with torch.load.
    torch.save(network.cpu(), output_folder + "/network.npy")