예제 #1
0
def main():
    """Evaluate a saved estimator/warp checkpoint on the test split.

    Configuration is read from the module-level ``args``; plots are sent to
    the module-level ``viz`` object.  For every epoch the test loader is
    iterated with gradients disabled.  At each rollout step the estimator
    predicts ``w`` from the current input window, ``warp`` builds the next
    frame from the last frame of the window and ``w``, and that frame is
    appended to the window in place of its oldest frame.  Photometric,
    smoothness, divergence and magnitude losses, plus the mean cosine
    similarity between ``w`` and ``w_targets``, are averaged over the rollout
    and accumulated in an ``AverageMeters`` instance.

    When ``args.save_ims`` is set, only a fixed list of batches is processed
    and a PNG is written for each of them.

    Raises:
        ValueError: if ``args.save_ims`` is set and
            ``args.test_target_seq_len`` is neither 10 nor 20 (no batch list
            is defined for other horizons).
    """
    global args, viz

    print('=> loading datasets...')
    dset = datasets.SSTSeq(
        args.train_root,
        seq_len=args.seq_len,
        target_seq_len=args.target_seq_len,
        zones=args.train_zones,
        rescale_method=args.rescale,
        time_slice=slice(None, 3000),
        normalize_uv=True,
    )

    # Same root as the training set; it differs by zones, horizon and the
    # time slice (everything from index 3000 onwards).
    test_dset = datasets.SSTSeq(
        args.train_root,
        seq_len=args.seq_len,
        target_seq_len=args.test_target_seq_len,
        zones=args.test_zones,
        rescale_method=args.rescale,
        time_slice=slice(3000, None),
        normalize_uv=True,
    )

    train_loader = DataLoader(dset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True,
                              pin_memory=True)
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)
    print('len(test)', len(test_dset))

    # 'train' is skipped below but stays in the mapping so that the window
    # index derived from the key order (see viz.image) does not change.
    splits = {
        'train': train_loader,
        'test': test_loader,
    }

    # NOTE(review): these freshly built modules are replaced by the ones
    # stored in the checkpoint a few lines below.
    estimator = estimators.ConvDeconvEstimator(input_channels=args.seq_len,
                                               upsample_mode=args.upsample)
    warp = warps.__dict__[args.warp]()
    print("=> creating warping scheme '{}'".format(args.warp))

    load_path = os.path.join(args.load_root, args.load_fn)
    print(f'Loading {load_path} ...')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    # map_location lets a checkpoint written on GPU be opened with --no_cuda.
    loaded = torch.load(load_path,
                        map_location='cpu' if args.no_cuda else None)
    print('loaded', loaded)
    estimator = loaded['estimator']
    warp = loaded['warp']

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())
    sim_loss = torch.nn.functional.cosine_similarity

    cudnn.benchmark = True

    # Staging buffers that every batch is copied into.
    _x, _ys = torch.Tensor(), torch.Tensor()

    if not args.no_cuda:
        print('=> to cuda')
        _x, _ys = _x.cuda(), _ys.cuda()
        warp.cuda()
        estimator.cuda()

    # Batches to export when --save_ims is given, keyed by test horizon.
    index = []
    if args.save_ims:
        samples_by_horizon = {
            10: [
                8, 15, 45, 59, 63, 64, 74, 76, 83, 90, 97, 107, 111, 125, 136,
                139, 155, 171, 176, 182, 204, 213, 215, 218, 223, 224, 226,
                232, 260, 263
            ],
            20: [
                63, 83, 101, 109, 118, 135, 149, 150, 153, 158, 170, 189, 200,
                229, 234
            ],
        }
        if args.test_target_seq_len not in samples_by_horizon:
            raise ValueError(
                'save_ims is only defined for test_target_seq_len 10 or 20, '
                'got {}'.format(args.test_target_seq_len))
        # File names use ``i + seq_len - 1`` below; this is the inverse shift.
        index = [
            s - args.seq_len + 1
            for s in samples_by_horizon[args.test_target_seq_len]
        ]
    index_set = frozenset(index)

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):

        results = {}
        for split, loader in splits.items():
            # Evaluation only: every split except 'test' is skipped.
            if split != 'test':
                continue

            meters = AverageMeters()
            estimator.eval()
            warp.eval()

            if args.save_ims:
                print(f'index: {index}')

            # Tensors of the last batch that was actually processed, so the
            # image sent to viz is built from one consistent batch.
            last_batch = None

            for i, (inputs, targets, w_targets) in enumerate(loader):

                if args.save_ims and i not in index_set:
                    continue
                print('i=', i)

                with torch.no_grad():
                    _x.resize_(inputs.size()).copy_(inputs)
                    _ys.resize_(targets.size()).copy_(targets)
                    x = Variable(_x)
                    # Swap the first two axes so that iterating ``ys`` yields
                    # one target per rollout step, and add a singleton axis.
                    ys = Variable(_ys.transpose(0, 1).unsqueeze(2))

                    pl = 0
                    sl = 0
                    dvl = 0
                    ml = 0
                    err_aee = 0

                    ims = []
                    ws = []
                    for j, y in enumerate(ys):

                        w = estimator(x)
                        im = warp(x[:, -1].unsqueeze(1), w)
                        # Slide the window: drop the oldest frame and append
                        # the prediction.
                        x = torch.cat([x[:, 1:], im], 1)

                        pl += photo_loss(im, y)
                        sl += smooth_loss(w)
                        dvl += div_loss(w)
                        ml += magn_loss(w)

                        # Follow the prediction's device instead of assuming
                        # CUDA, so --no_cuda works.
                        err_aee += sim_loss(w, w_targets[:, j].to(
                            w.device)).mean()

                        ims.append(im.cpu().data.numpy())
                        ws.append(w.cpu().data.numpy())

                    pl /= args.test_target_seq_len
                    sl /= args.test_target_seq_len
                    dvl /= args.test_target_seq_len
                    ml /= args.test_target_seq_len
                    err_aee /= args.test_target_seq_len

                    loss = (pl + args.smooth_coef * sl +
                            args.div_coef * dvl + args.magn_coef * ml)

                meters.update(
                    dict(
                        loss=loss.item(),
                        pl=pl.item(),
                        dl=dvl.item(),
                        sl=sl.item(),
                        ml=ml.item(),
                        err_aae=err_aee.item(),
                    ),
                    n=x.size(0))
                last_batch = (inputs, ys, ims, ws, w_targets)

                if args.save_ims:
                    import matplotlib.pyplot as plt
                    images_save_root = '/net/drunk/debezenac/data/flow_icml/saved_images/png/sst/test/adv/' + str(
                        args.test_target_seq_len)
                    os.makedirs(images_save_root, exist_ok=True)
                    nplots = 1
                    images_save_fn = os.path.join(
                        images_save_root,
                        f'images_{load_path.replace("/", "_")}_{i + args.seq_len - 1}.png'
                    )
                    plot.plot_results(
                        [
                            ('im', {
                                'out': ims
                            }),
                            ('ws', {
                                'out': ws
                            }),
                        ],
                        min(nplots, args.batch_size),
                        cmap='tarn',
                        renorm=False)
                    print(images_save_fn)
                    plt.savefig(images_save_fn)

            if not args.no_plot and last_batch is not None:
                inputs, ys, ims, ws, w_targets = last_batch
                images = [
                    ('target', {
                        'in': inputs.transpose(0, 1).numpy(),
                        'out': ys.cpu().data.numpy()
                    }),
                    ('im', {
                        'out': ims
                    }),
                    ('ws', {
                        'out': ws
                    }),
                    ('ws_target', {
                        'out': w_targets.transpose(0, 1).numpy()
                    }),
                ]

                rendered = plot.from_matplotlib(plot.plot_results(images))
                viz.image(
                    rendered.transpose(2, 0, 1),
                    opts=dict(
                        title='{}, epoch {}'.format(split.upper(), epoch)),
                    win=list(splits).index(split),
                )

            results[split] = meters.avgs()
            print('\n\nEpoch: {} {}: {}\t'.format(epoch, split, meters))
            print('seq_len:', args.test_target_seq_len)

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        res = {}
        legend = []
        for split in results:
            legend.append(split)
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One line plot per metric, created on first use and appended to
        # afterwards.
        for metric, values in res.items():
            line_y = np.expand_dims(np.array(values), 0)
            line_x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=line_x, Y=line_y, opts=opts)
            else:
                viz.line(X=line_x,
                         Y=line_y,
                         opts=opts,
                         win=viz_wins[metric],
                         update='append')
예제 #2
0
파일: train.py 프로젝트: Panpanx/flow
def main():
    """Train the estimator and report losses on the train/valid/test splits.

    Configuration is read from the module-level ``args``; plots are sent to
    the module-level ``viz`` object.  The training dataset is split by index
    into a train part (first ``args.split`` fraction) and a validation part
    (the rest); a separate dataset rooted at ``args.test_root`` is used for
    testing.

    For each batch the estimator predicts ``w`` from the current input
    window, ``warp`` builds the next frame from the last frame of the window
    and ``w``, and that frame replaces the oldest frame of the window for the
    next step.  The loss is the photometric loss plus the weighted
    smoothness, divergence and magnitude losses, each averaged over the
    rollout.  Only the 'train' split updates the estimator.
    """
    global args, viz

    print('=> loading datasets...')
    dset = datasets.SSTSeq(
        args.train_root,
        seq_len=args.seq_len,
        target_seq_len=args.target_seq_len,
        zones=args.train_zones,
        rescale_method=args.rescale,
    )

    test_dset = datasets.SSTSeq(
        args.test_root,
        seq_len=args.seq_len,
        target_seq_len=args.target_seq_len,
        zones=args.test_zones,
        rescale_method=args.rescale,
    )

    n_train = int(len(dset) * args.split)
    train_indices = range(0, n_train)
    val_indices = range(n_train, len(dset))

    train_loader = DataLoader(dset,
                              batch_size=args.batch_size,
                              sampler=SubsetRandomSampler(train_indices),
                              num_workers=args.workers,
                              pin_memory=True)
    val_loader = DataLoader(dset,
                            batch_size=args.batch_size,
                            sampler=SubsetRandomSampler(val_indices),
                            num_workers=args.workers,
                            pin_memory=True)
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True)

    splits = {
        'train': train_loader,
        'valid': val_loader,
        'test': test_loader,
    }

    estimator = estimators.ConvDeconvEstimator(input_channels=args.seq_len,
                                               upsample_mode=args.upsample)
    warp = warps.__dict__[args.warp]()
    print("=> creating warping scheme '{}'".format(args.warp))

    # Staging buffers that every batch is copied into.
    _x, _ys = torch.Tensor(), torch.Tensor()

    # Move everything to the GPU only when allowed, and do it before the
    # optimizer is built so it is created on the final parameters.
    if not args.no_cuda:
        print('=> to cuda')
        _x, _ys = _x.cuda(), _ys.cuda()
        estimator = estimator.cuda()
        warp = warp.cuda()

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())

    cudnn.benchmark = True
    optimizer = torch.optim.Adam(estimator.parameters(),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):

        results = {}
        for split, data_loader in splits.items():

            meters = AverageMeters()

            is_train = split == 'train'
            if is_train:
                estimator.train()
                warp.train()
            else:
                estimator.eval()
                warp.eval()

            for i, (inputs, targets) in enumerate(data_loader):

                _x.resize_(inputs.size()).copy_(inputs)
                _ys.resize_(targets.size()).copy_(targets)
                x = Variable(_x)
                # Swap the first two axes so that iterating ``ys`` yields one
                # target per rollout step, and add a singleton axis.
                ys = Variable(_ys.transpose(0, 1).unsqueeze(2))

                pl = 0
                sl = 0
                dl = 0
                ml = 0

                ims = []
                ws = []

                # No autograd graph is needed for the valid/test splits.
                with torch.set_grad_enabled(is_train):
                    for y in ys:

                        w = estimator(x)
                        im = warp(x[:, -1].unsqueeze(1), w)
                        # Slide the window: drop the oldest frame and append
                        # the prediction.
                        x = torch.cat([x[:, 1:], im], 1)

                        curr_pl = photo_loss(im, y)
                        pl += torch.mean(curr_pl)
                        sl += smooth_loss(w)
                        dl += div_loss(w)
                        ml += magn_loss(w)

                        ims.append(im.cpu().data.numpy())
                        ws.append(w.cpu().data.numpy())

                    pl /= args.target_seq_len
                    sl /= args.target_seq_len
                    dl /= args.target_seq_len
                    ml /= args.target_seq_len

                    loss = (pl + args.smooth_coef * sl + args.div_coef * dl +
                            args.magn_coef * ml)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # .item() instead of .data[0]: indexing a 0-dim tensor is an
                # error on current PyTorch releases.
                meters.update(dict(
                    loss=loss.item(),
                    pl=pl.item(),
                    dl=dl.item(),
                    sl=sl.item(),
                    ml=ml.item(),
                ),
                              n=x.size(0))

            if not args.no_plot:
                # Uses the tensors of the last batch of this split.
                images = [
                    ('target', {
                        'in': inputs.transpose(0, 1).numpy(),
                        'out': ys.cpu().data.numpy()
                    }),
                    ('im', {
                        'out': ims
                    }),
                    ('ws', {
                        'out': ws
                    }),
                ]
                rendered = plot.from_matplotlib(plot.plot_results(images))
                viz.image(
                    rendered.transpose(2, 0, 1),
                    opts=dict(
                        title='{}, epoch {}'.format(split.upper(), epoch)),
                    win=list(splits).index(split),
                )

            results[split] = meters.avgs()
            print('\n\nEpoch: {} {}: {}\t'.format(epoch, split, meters))

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        res = {}
        legend = []
        for split in results:
            legend.append(split)
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One line plot per metric, created on first use and appended to
        # afterwards.
        for metric, values in res.items():
            line_y = np.expand_dims(np.array(values), 0)
            line_x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=line_x, Y=line_y, opts=opts)
            else:
                viz.line(X=line_x,
                         Y=line_y,
                         opts=opts,
                         win=viz_wins[metric],
                         update='append')
예제 #3
0
def main():
    """Train the estimator, evaluate on the test split and save checkpoints.

    Configuration is read from the module-level ``args``; plots are sent to
    the module-level ``viz`` object.  Train and test datasets share
    ``args.train_root`` and differ by zones, horizon and time slice (before /
    from index 3000).

    For each batch the estimator predicts ``w`` from the current input
    window, ``warp`` builds the next frame from the last frame of the window
    and ``w``, and that frame replaces the oldest frame of the window for the
    next step.  The loss is the photometric loss plus the weighted
    smoothness, divergence and magnitude losses, each averaged over the
    rollout; ``losses.AAE`` against ``w_targets`` is tracked as ``err_aae``.
    Only the 'train' split updates the estimator.

    Every ``args.save_every`` epochs, starting at ``args.save_start``, the
    estimator, warp, optimizer and the test metrics are written with
    ``torch.save`` under ``args.save_root``.
    """
    global args, viz

    print('=> loading datasets...')
    dset = datasets.SSTSeq(
        args.train_root,
        seq_len=args.seq_len,
        target_seq_len=args.target_seq_len,
        zones=args.train_zones,
        rescale_method=args.rescale,
        time_slice=slice(None, 3000),
        normalize_uv=True,
    )

    test_dset = datasets.SSTSeq(
        args.train_root,
        seq_len=args.seq_len,
        target_seq_len=args.test_target_seq_len,
        zones=args.test_zones,
        rescale_method=args.rescale,
        time_slice=slice(3000, None),
        normalize_uv=True,
    )

    train_loader = DataLoader(
        dset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    test_loader = DataLoader(test_dset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             drop_last=True,
                             num_workers=args.workers,
                             pin_memory=True)

    splits = {
        'train': train_loader,
        'test': test_loader,
    }

    estimator = estimators.ConvDeconvEstimator(input_channels=args.seq_len,
                                               upsample_mode=args.upsample)
    warp = warps.__dict__[args.warp]()
    print("=> creating warping scheme '{}'".format(args.warp))

    # Staging buffers that every batch is copied into.
    _x, _ys = torch.Tensor(), torch.Tensor()

    # Move everything to the GPU only when allowed, and do it before the
    # optimizer is built so it is created on the final parameters.
    if not args.no_cuda:
        print('=> to cuda')
        _x, _ys = _x.cuda(), _ys.cuda()
        estimator = estimator.cuda()
        warp = warp.cuda()

    photo_loss = torch.nn.MSELoss()
    smooth_loss = losses.SmoothnessLoss(torch.nn.MSELoss())
    div_loss = losses.DivergenceLoss(torch.nn.MSELoss())
    magn_loss = losses.MagnitudeLoss(torch.nn.MSELoss())

    cudnn.benchmark = True
    optimizer = torch.optim.Adam(estimator.parameters(),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    viz_wins = {}
    for epoch in range(1, args.epochs + 1):

        results = {}
        for split, data_loader in splits.items():

            meters = AverageMeters()

            is_train = split == 'train'
            if is_train:
                estimator.train()
                warp.train()
            else:
                estimator.eval()
                warp.eval()

            for i, (inputs, targets, w_targets) in enumerate(data_loader):

                _x.resize_(inputs.size()).copy_(inputs)
                _ys.resize_(targets.size()).copy_(targets)
                x = Variable(_x)
                # Swap the first two axes so that iterating ``ys`` yields one
                # target per rollout step, and add a singleton axis.
                ys = Variable(_ys.transpose(0, 1).unsqueeze(2))

                # The test set is built with args.test_target_seq_len and the
                # train set with args.target_seq_len, so average over the
                # number of steps actually rolled out for this batch.
                n_steps = len(ys)

                pl = 0
                sl = 0
                dvl = 0
                ml = 0
                err_aee = 0

                ims = []
                ws = []

                # No autograd graph is needed for the test split.
                with torch.set_grad_enabled(is_train):
                    for j, y in enumerate(ys):

                        w = estimator(x)
                        im = warp(x[:, -1].unsqueeze(1), w)
                        # Slide the window: drop the oldest frame and append
                        # the prediction.
                        x = torch.cat([x[:, 1:], im], 1)

                        pl += photo_loss(im, y)
                        sl += smooth_loss(w)
                        dvl += div_loss(w)
                        ml += magn_loss(w)

                        # Follow the prediction's device instead of assuming
                        # CUDA, so --no_cuda works.
                        err_aee += losses.AAE(w,
                                              w_targets[:, j].to(w.device))

                        ims.append(im.cpu().data.numpy())
                        ws.append(w.cpu().data.numpy())

                    pl /= n_steps
                    sl /= n_steps
                    dvl /= n_steps
                    ml /= n_steps
                    err_aee /= n_steps

                    loss = (pl + args.smooth_coef * sl +
                            args.div_coef * dvl + args.magn_coef * ml)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                meters.update(
                    dict(
                        loss=loss.item(),
                        pl=pl.item(),
                        dl=dvl.item(),
                        sl=sl.item(),
                        ml=ml.item(),
                        err_aae=err_aee.item(),
                    ),
                    n=x.size(0))

            if not args.no_plot:
                # Uses the tensors of the last batch of this split.
                images = [
                    ('target', {
                        'in': inputs.transpose(0, 1).numpy(),
                        'out': ys.cpu().data.numpy()
                    }),
                    ('im', {
                        'out': ims
                    }),
                    ('ws', {
                        'out': ws
                    }),
                ]
                rendered = plot.from_matplotlib(plot.plot_results(images))
                viz.image(
                    rendered.transpose(2, 0, 1),
                    opts=dict(
                        title='{}, epoch {}'.format(split.upper(), epoch)),
                    win=list(splits).index(split),
                )

            results[split] = meters.avgs()
            print('\n\nEpoch: {} {}: {}\t'.format(epoch, split, meters))

        # Transpose {split: {metric: avg}} into {metric: [avg per split]}.
        res = {}
        legend = []
        for split in results:
            legend.append(split)
            for metric, avg in results[split].items():
                res.setdefault(metric, []).append(avg)

        # One line plot per metric, created on first use and appended to
        # afterwards.
        for metric, values in res.items():
            line_y = np.expand_dims(np.array(values), 0)
            line_x = np.array([[epoch] * len(results)])
            opts = dict(showlegend=True, legend=legend, title=metric)
            if metric not in viz_wins:
                viz_wins[metric] = viz.line(X=line_x, Y=line_y, opts=opts)
            else:
                viz.line(X=line_x,
                         Y=line_y,
                         opts=opts,
                         win=viz_wins[metric],
                         update='append')

        if (epoch % args.save_every == 0) and (epoch >= args.save_start):
            to_save = {
                'epoch': epoch,
                'estimator': estimator,
                'warp': warp,
                'optim': optimizer,
                'err_obs': results['test']['pl'],
                'err_aae': results['test']['err_aae'],
            }
            time_str = datetime.now().strftime("%a-%b-%d-%H:%M:%S.%f")
            # Make sure the target directory exists before writing.
            os.makedirs(args.save_root, exist_ok=True)
            save_path = os.path.join(args.save_root, f'{time_str}_{epoch}.pt')
            print(f'Saving modules to {save_path} ...')
            torch.save(to_save, save_path)