Example #1
def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset, help='data file')
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz',
                        type=int,
                        default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch',
                        type=int,
                        default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=8e-5,
                        help='encoder/decoder learning rate')
    parser.add_argument('--dis-lr',
                        type=float,
                        default=1e-4,
                        help='discriminator learning rate')
    parser.add_argument('--min-lr',
                        type=float,
                        default=5e-5,
                        help='min encoder/decoder learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--min-dis-lr',
                        type=float,
                        default=7e-5,
                        help='min discriminator learning rate for LR '
                        'scheduler. -1 to disable annealing')
    parser.add_argument('--wd', type=float, default=0, help='weight decay')
    parser.add_argument('--overlap',
                        type=float,
                        default=.5,
                        help='kernel overlap')
    parser.add_argument('--no-norm-trans',
                        action='store_true',
                        help='if set, use Gaussian posterior without '
                        'transformation')
    parser.add_argument('--plot-interval',
                        type=int,
                        default=1,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval',
                        type=int,
                        default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix',
                        default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--comp',
                        type=int,
                        default=7,
                        help='continuous convolution kernel size')
    parser.add_argument('--ae',
                        type=float,
                        default=.2,
                        help='autoencoding regularization strength')
    parser.add_argument('--aeloss',
                        default='smooth_l1',
                        help='autoencoding loss. (options: mse, smooth_l1)')
    parser.add_argument('--ema',
                        dest='ema',
                        type=int,
                        default=-1,
                        help='start epoch of exponential moving average '
                        '(EMA). -1 to disable EMA')
    parser.add_argument('--ema-decay',
                        type=float,
                        default=.9999,
                        help='EMA decay')
    parser.add_argument('--mmd',
                        type=float,
                        default=1,
                        help='MMD strength for latent variable')

    # squash is off when rescale is off
    parser.add_argument('--squash',
                        dest='squash',
                        action='store_const',
                        const=True,
                        default=True,
                        help='bound the generated time series value '
                        'using tanh')
    parser.add_argument('--no-squash',
                        dest='squash',
                        action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale',
                        dest='rescale',
                        action='store_const',
                        const=True,
                        default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale',
                        dest='rescale',
                        action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000,
            seq_len=200,
            max_time=1,
            poisson_rate=50,
            obs_span_rate=.25,
            save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(train_data,
                               train_time,
                               train_mask,
                               label=None,
                               max_time=max_time,
                               cconv_ref=cconv_ref,
                               overlap_rate=args.overlap,
                               device=device)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    time_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             drop_last=True,
                             collate_fn=train_dataset.collate_fn)

    test_loader = DataLoader(train_dataset,
                             batch_size=batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(in_channels,
                             out_channels,
                             max_time,
                             cconv_ref,
                             overlap_rate=args.overlap,
                             kernel_size=args.comp,
                             norm=True).to(device)
    encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)

    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic_cconv = ContinuousConv1D(in_channels,
                                    out_channels,
                                    max_time,
                                    cconv_ref,
                                    overlap_rate=args.overlap,
                                    kernel_size=args.comp,
                                    norm=True).to(device)
    critic = ConvCritic(critic_cconv, nz).to(device)

    ema = None
    if args.ema >= 0:
        ema = EMA(pbigan, args.ema_decay, args.ema)

    optimizer = optim.Adam(pbigan.parameters(),
                           lr=args.lr,
                           weight_decay=args.wd)
    critic_optimizer = optim.Adam(critic.parameters(),
                                  lr=args.dis_lr,
                                  weight_decay=args.wd)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
    dis_scheduler = make_scheduler(critic_optimizer, args.dis_lr,
                                   args.min_dis_lr, epochs)

    path = '{}_{}'.format(args.prefix, datetime.now().strftime('%m%d.%H%M%S'))

    output_dir = Path('results') / 'toy-pbigan' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)

        for ((val, idx, mask, _, cconv_graph),
             (_, idx_t, mask_t, index, _)) in zip(train_loader, time_loader):

            z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
                val, idx, mask, cconv_graph, idx_t, mask_t)

            cconv_graph_gen = train_dataset.make_graph(x_gen, idx_t, mask_t,
                                                       index)

            real = critic(cconv_graph, batch_size, z_enc)
            fake = critic(cconv_graph_gen, batch_size, z_gen)

            D_loss = gan_loss(real, fake, 1, 0)

            critic_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            critic_optimizer.step()

            G_loss = gan_loss(real, fake, 0, 1)

            mmd_loss = mmd(z_enc, z_gen)

            loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if ema:
                ema.update()

            loss_breakdown['D'] += D_loss.item()
            loss_breakdown['G'] += G_loss.item()
            loss_breakdown['AE'] += ae_loss.item()
            loss_breakdown['MMD'] += mmd_loss.item()
            loss_breakdown['total'] += loss.item()

        if scheduler:
            scheduler.step()
        if dis_scheduler:
            dis_scheduler.step()

        cur_time = time.time()
        tracker.log(epoch, loss_breakdown, cur_time - epoch_start,
                    cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            if ema:
                ema.apply()
                visualizer.plot(epoch)
                ema.restore()
            else:
                visualizer.plot(epoch)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'ema': ema.state_dict() if ema else None,
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
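
This example relies on several project helpers that are not shown in the listing: mkdir, make_scheduler, gan_loss and mmd. Below is a minimal sketch of plausible implementations, assuming the critic returns unnormalised logits, a linear learning-rate anneal down to the given minimum, and a Gaussian-kernel MMD; the names and signatures mirror the call sites above, but the bodies are assumptions rather than the original code.

# Hypothetical sketches of helpers used above; signatures follow the call sites,
# bodies are assumptions.
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import optim


def mkdir(path):
    # Create the directory (and parents) if needed and return it as a Path.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path


def make_scheduler(optimizer, lr, min_lr, epochs):
    # Linear anneal from lr to min_lr over the run; returning None disables
    # annealing, matching the '-1 to disable annealing' help text and the
    # `if scheduler:` checks in the training loop.
    if min_lr < 0:
        return None
    steps = max(epochs - 1, 1)

    def factor(epoch):
        t = min(epoch / steps, 1.0)
        return (lr + (min_lr - lr) * t) / lr

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=factor)


def gan_loss(real, fake, real_target, fake_target):
    # Binary cross-entropy on critic logits against the given real/fake targets.
    real_loss = F.binary_cross_entropy_with_logits(
        real, torch.full_like(real, float(real_target)))
    fake_loss = F.binary_cross_entropy_with_logits(
        fake, torch.full_like(fake, float(fake_target)))
    return real_loss + fake_loss


def mmd(x, y, scales=(.2, .5, 1., 2., 5.)):
    # Maximum mean discrepancy between two latent batches of shape (batch, nz),
    # using a mixture of Gaussian kernels.
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)
        return sum(torch.exp(-d2 / (2 * s ** 2)) for s in scales)

    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()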
Example #2
    print(dagger_costs)
    print("\n\n\n")
    im_accs = np.array([im_accs])
    dagger_accs = np.array([dagger_accs])


    data_dir = './data/raw/'
    np.save(data_dir + 'sup_costs', sup_costs)
    np.save(data_dir + 'im_costs', im_costs)
    np.save(data_dir + 'dagger_costs', dagger_costs)

    filename = data_dir + 'cost_comparisons.eps'
    vis.get_perf(sup_costs)
    vis.get_perf(im_costs)
    vis.get_perf(dagger_costs)
    vis.plot(["Supervisor", "Supervised", "DAgger"], "Cost", filename)

    filename = data_dir + 'surrogate_loss_comparisons.eps'
    vis.get_perf(im_loss_data)
    vis.get_perf(dagger_loss_data)
    vis.plot(["Supervised", "DAgger"], "Loss", filename)



    filename = data_dir + 'accuracy_comparisons.eps'
    vis.get_perf(im_acc_data)
    vis.get_perf(dagger_acc_data)
    vis.plot(['Supervised', 'DAgger'], "Acc", filename)


    """
Example #3
def train(**kwargs):
    # opt = Config()
    for k, v in kwargs.items():
        setattr(opt, k, v)
    vis = Visualizer(opt.env, opt.port)
    device = t.device('cuda') if opt.use_gpu else t.device('cpu')
    lr = opt.lr

    # network configuration
    featurenet = FeatureNet(4, 5)
    if opt.model_path:
        featurenet.load_state_dict(
            t.load(opt.model_path, map_location=lambda _s, _: _s))
    featurenet.to(device)

    # load data
    data_set = dataset.FeatureDataset(root=opt.data_root,
                                      train=True,
                                      test=False)
    dataloader = DataLoader(data_set,
                            batch_size=opt.batch_size,
                            shuffle=True,
                            num_workers=opt.num_workers)
    val_dataset = dataset.FeatureDataset(root=opt.data_root,
                                         train=False,
                                         test=False)
    val_dataloader = DataLoader(val_dataset,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)
    # define the optimizer and loss function
    optimizer = t.optim.SGD(featurenet.parameters(), lr)
    criterion = t.nn.CrossEntropyLoss().to(device)

    # track key metrics
    loss_meter = AverageValueMeter()

    # start training
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        for ii, (data, label) in enumerate(dataloader):
            feature = data.to(device)
            target = label.to(device)

            optimizer.zero_grad()
            prob = featurenet(feature)
            # print(prob)
            # print(target)
            loss = criterion(prob, target)
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.item())

            if (ii + 1) % opt.plot_every == 0:
                vis.plot('train_loss', loss_meter.value()[0])
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()
        t.save(
            featurenet.state_dict(),
            'checkpoints/{epoch}_{time}_{loss}.pth'.format(
                epoch=epoch,
                time=time.strftime('%m%d_%H_%M_%S'),
                loss=loss_meter.value()[0]))

        # validation and visualization
        accu, loss = val(featurenet, val_dataloader, criterion)
        featurenet.train()
        vis.plot('val_loss', loss)
        vis.log('epoch: {epoch}, loss: {loss}, accu: {accu}'.format(
            epoch=epoch, loss=loss, accu=accu))

        lr = lr * 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
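
The validation step above calls a val() helper that is not included in this listing. A minimal sketch of such a helper, assuming it puts the model in eval mode, averages the cross-entropy loss over the validation set and reports top-1 accuracy (the caller switches back to train mode afterwards); the signature mirrors the call, the body is an assumption.

# Hypothetical sketch of the val() helper called above; the body is an assumption.
import torch as t


@t.no_grad()
def val(model, dataloader, criterion):
    model.eval()
    device = next(model.parameters()).device
    total_loss, correct, total = 0.0, 0, 0
    for data, label in dataloader:
        data, label = data.to(device), label.to(device)
        logits = model(data)
        total_loss += criterion(logits, label).item() * label.size(0)
        correct += (logits.argmax(dim=1) == label).sum().item()
        total += label.size(0)
    return correct / total, total_loss / total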
Example #4
def main():
    parser = argparse.ArgumentParser()

    default_dataset = 'toy-data.npz'
    parser.add_argument('--data', default=default_dataset,
                        help='data file')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed. Randomly set if not specified.')

    # training options
    parser.add_argument('--nz', type=int, default=32,
                        help='dimension of latent variable')
    parser.add_argument('--epoch', type=int, default=1000,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')
    parser.add_argument('--plot-interval', type=int, default=10,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--prefix', default='pvae',
                        help='prefix of output directory')
    parser.add_argument('--comp', type=int, default=5,
                        help='continuous convolution kernel size')
    parser.add_argument('--sigma', type=float, default=.2,
                        help='standard deviation for Gaussian likelihood')
    parser.add_argument('--overlap', type=float, default=.5,
                        help='kernel overlap')
    # squash is off when rescale is off
    parser.add_argument('--squash', dest='squash', action='store_const',
                        const=True, default=True,
                        help='bound the generated time series value '
                             'using tanh')
    parser.add_argument('--no-squash', dest='squash', action='store_const',
                        const=False)

    # rescale to [-1, 1]
    parser.add_argument('--rescale', dest='rescale', action='store_const',
                        const=True, default=True,
                        help='if set, rescale time to [-1, 1]')
    parser.add_argument('--no-rescale', dest='rescale', action='store_const',
                        const=False)

    args = parser.parse_args()

    batch_size = args.batch_size
    nz = args.nz

    epochs = args.epoch
    plot_interval = args.plot_interval
    save_interval = args.save_interval

    try:
        npz = np.load(args.data)
        train_data = npz['data']
        train_time = npz['time']
        train_mask = npz['mask']
    except FileNotFoundError:
        if args.data != default_dataset:
            raise
        # Generate the default toy dataset from scratch
        train_data, train_time, train_mask, _, _ = gen_data(
            n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
            obs_span_rate=.25, save_file=default_dataset)

    _, in_channels, seq_len = train_data.shape
    train_time *= train_mask

    if args.seed is None:
        rnd = np.random.RandomState(None)
        random_seed = rnd.randint(np.iinfo(np.uint32).max)
    else:
        random_seed = args.seed
    rnd = np.random.RandomState(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Scale time
    max_time = 5
    train_time *= max_time

    squash = None
    rescaler = None
    if args.rescale:
        rescaler = Rescaler(train_data)
        train_data = rescaler.rescale(train_data)
        if args.squash:
            squash = torch.tanh

    out_channels = 64
    cconv_ref = 98

    train_dataset = TimeSeries(
        train_data, train_time, train_mask, label=None, max_time=max_time,
        cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        drop_last=True, collate_fn=train_dataset.collate_fn)
    n_train_batch = len(train_loader)

    test_batch_size = 64
    test_loader = DataLoader(train_dataset, batch_size=test_batch_size,
                             collate_fn=train_dataset.collate_fn)

    grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
    decoder = Decoder(grid_decoder, max_time=max_time).to(device)

    cconv = ContinuousConv1D(
        in_channels, out_channels, max_time, cconv_ref,
        overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)

    encoder = Encoder(nz, cconv).to(device)

    pvae = PVAE(encoder, decoder, sigma=args.sigma).to(device)

    optimizer = optim.Adam(pvae.parameters(), lr=args.lr)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
        '_'.join([f'lr_{args.lr:g}']))

    output_dir = Path('results') / 'toy-pvae' / path
    print(output_dir)
    log_dir = mkdir(output_dir / 'log')
    model_dir = mkdir(output_dir / 'model')

    start_epoch = 0

    with (log_dir / 'seed.txt').open('w') as f:
        print(random_seed, file=f)
    with (log_dir / 'gpu.txt').open('a') as f:
        print(torch.cuda.device_count(), start_epoch, file=f)
    with (log_dir / 'args.txt').open('w') as f:
        for key, val in sorted(vars(args).items()):
            print(f'{key}: {val}', file=f)

    tracker = Tracker(log_dir, n_train_batch)
    visualizer = Visualizer(encoder, decoder, test_batch_size, max_time,
                            test_loader, rescaler, output_dir, device)
    start = time.time()
    epoch_start = start

    for epoch in range(start_epoch, epochs):
        loss_breakdown = defaultdict(float)
        for val, idx, mask, _, cconv_graph in train_loader:
            optimizer.zero_grad()
            loss = pvae(val, idx, mask, cconv_graph)
            loss.backward()
            optimizer.step()
            loss_breakdown['loss'] += loss.item()

        if scheduler:
            scheduler.step()

        cur_time = time.time()
        tracker.log(
            epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)

        if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
            visualizer.plot(epoch)

        model_dict = {
            'pvae': pvae.state_dict(),
            'epoch': epoch + 1,
            'args': args,
        }
        torch.save(model_dict, str(log_dir / 'model.pth'))
        if save_interval > 0 and (epoch + 1) % save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)
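
Both toy-data examples (Example #1 and Example #4) construct a Rescaler from the training data, rescale the values with it, and hand it to the Visualizer, but the class itself is not part of this listing. A minimal sketch under the assumption of a per-channel min-max mapping to [-1, 1]; the constructor and rescale method mirror the usage above, while the inverse mapping is an extra convenience and equally an assumption.

# Hypothetical sketch of the Rescaler used in both toy-data examples;
# the min-max mapping is an assumption.
import numpy as np


class Rescaler:
    def __init__(self, data):
        # Per-channel min/max over samples and time steps: data has shape (N, C, L).
        self.min = data.min(axis=(0, 2), keepdims=True)
        self.max = data.max(axis=(0, 2), keepdims=True)

    def rescale(self, data):
        # Map values linearly into [-1, 1].
        return 2 * (data - self.min) / (self.max - self.min + 1e-8) - 1

    def unrescale(self, data):
        # Inverse mapping back to the original value range.
        return (data + 1) / 2 * (self.max - self.min + 1e-8) + self.min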
Example #5
def train(**kwargs):
    for k_, v_ in kwargs.items():
        setattr(config, k_, v_)

    vis_ = Visualizer()

    # data
    train_dataset = get_data.Ali(config.train_path, 'train',
                                 config.feature_index_path)
    val_dataset = get_data.Ali(config.val_path, 'val',
                               config.feature_index_path)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=config.batch_size)

    # model
    model = deepfm.FNN(config.feature_index_path)
    print(model)

    # print('initializing...')
    # model.apply(weight_init)

    # testing
    if config.test_flag:
        test_dataset = get_data.Ali(config.test_path, 'test',
                                    config.feature_index_path)
        test_loader = DataLoader(dataset=test_dataset,
                                 batch_size=config.batch_size)
        model.load_state_dict(
            torch.load(os.path.join(config.model_path, '_best')))
        test(model, test_loader, config.output_path)

    # criterion and optimizer
    criterion = torch.nn.BCELoss()
    lr = config.lr
    optimizer = Adam(model.parameters(),
                     lr=lr,
                     betas=(config.beta1, config.beta2),
                     weight_decay=config.weight_decay)
    previous_loss = 1e6
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    # meters
    loss_meter = tnt.meter.AverageValueMeter()
    # class_err = tnt.meter.ClassErrorMeter()
    # confusion_matrix = tnt.meter.ConfusionMeter(2, normalized=True)

    # val(model, val_loader, criterion)
    # resume training
    start = 0
    if config.resume:
        model_epoch = [
            int(fname.split('_')[-1])
            for fname in os.listdir(config.model_path) if 'best' not in fname
        ]
        start = max(model_epoch)
        model.load_state_dict(
            torch.load(os.path.join(config.model_path, f'_epoch_{start}')))
    if start >= config.epochs:
        print('Training already Done!')
        return

    # train
    print('start training...')
    for i in range(start, config.epochs):
        loss_meter.reset()
        # class_err.reset()
        # confusion_matrix.reset()
        for ii, (c_data, labels) in tqdm(enumerate(train_loader)):
            c_data = to_var(c_data)
            labels = to_var(labels).float()
            # labels = labels.view(-1, 1)

            pred = model(c_data)
            # print(pred, labels)
            loss = criterion(pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # meters update and visualize
            loss_meter.add(loss.item())
            # confusion_matrix.add(pred.data.squeeze(), labels.data.type(torch.LongTensor))

            if (ii + 1) % config.print_every == 0:
                vis_.plot('train_loss', loss_meter.value()[0])
                print(f'epochs: {i + 1}/{config.epochs} '
                      f'batch: {ii + 1}/{len(train_loader)} '
                      f'train_loss: {loss.item()}')

        print('evaluating...')
        # train_cm = confusion_matrix.value()
        val_cm, val_accuracy, val_loss = val(model, val_loader, criterion)
        vis_.plot('val_loss', val_loss)
        vis_.log(f"epoch:{start + 1},lr:{lr},loss:{val_loss}")

        torch.save(model.state_dict(),
                   os.path.join(config.model_path, f'_epoch_{i}'))

        # update learning rate
        if loss_meter.value()[0] > previous_loss:
            torch.save(model.state_dict(),
                       os.path.join(config.model_path, '_best'))
            lr = lr * config.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()[0]
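
The training loop wraps each batch with a to_var helper before the forward pass. A minimal sketch of what that helper could be, assuming it simply moves tensors to the GPU when one is available (in pre-0.4 PyTorch it would also have wrapped them in Variable); the name follows the call above, the body is an assumption.

# Hypothetical sketch of the to_var helper used above; the body is an assumption.
import torch


def to_var(x):
    # Move the tensor to the GPU when one is available and return it.
    if torch.cuda.is_available():
        x = x.cuda()
    return x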
Example #6
class Trainer:
    def __init__(self, config):

        # data
        self.train_data = AnnPolygon(config, train=True)
        self.val_data = AnnPolygon(config, train=False)
        self.num_points = config.num_points
        self.gt_num_points = config.gt_num_points

        # model
        self.model = Snake(state_dim=config.batch_size,
                           feature_dim=6,
                           conv_type='dgrid')

        # Run on GPU/CPU
        if torch.cuda.is_available():
            self.device = torch.device(f"cuda:{config.multi_gpu[0]}")
        else:
            self.device = torch.device(f"cpu")

        # check if load from existing model
        if config.reload_model_path:
            self.model.load(config.reload_model_path)
        # move the model to the selected device in either case
        self.model.to(self.device)

        self.train_dataloader = DataLoader(self.train_data,
                                           config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers,
                                           drop_last=True)
        self.val_dataloader = DataLoader(self.val_data,
                                         config.batch_size,
                                         shuffle=True,
                                         num_workers=config.num_workers,
                                         drop_last=True)

        # optimizer
        self.criterion = dist_chamfer_2D.chamfer_2DDist()
        if config.optimizer == "adam":
            self.lr = config.learning_rate
            self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        elif config.optimizer == "sgd":
            self.lr = config.learning_rate / 10
            self.optimizer = SGD(self.model.parameters(), lr=self.lr)

        self.train_loss = meter.AverageValueMeter()
        self.val_loss = meter.AverageValueMeter()

        self.epochs = config.epochs

        self.vis = Visualizer(config)

    def train(self):
        for epoch in range(self.epochs):
            self.train_loss.reset()
            # training iteration
            for i, (curve, GT) in enumerate(self.train_dataloader):
                # load
                GT_points = GT.type(torch.FloatTensor).cuda(device=self.device)
                curve = curve.type(torch.FloatTensor).cuda(device=self.device)
                coodinates = curve[:, :, :2]

                curve = curve.permute(0, 2, 1)
                coodinates = coodinates.permute(0, 2, 1)

                # feed into model
                offset = self.model(curve)
                new_coodinates = offset + coodinates

                dist1, dist2, _, _ = self.criterion(
                    GT_points, new_coodinates.permute(0, 2, 1))
                loss = torch.mean(dist1) / self.num_points + torch.mean(
                    dist2) / self.gt_num_points

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.train_loss.add(loss.item())
                self.vis.plot('train_loss',
                              self.train_loss.value()[0],
                              i + epoch * len(self.train_data))

            self.model.eval()

            for i, (curve, GT) in enumerate(self.val_dataloader):
                # load data
                GT_points = GT.type(torch.FloatTensor).cuda(device=self.device)
                curve = curve.type(torch.FloatTensor).cuda(device=self.device)
                coodinates = curve[:, :, :2]

                curve = curve.permute(0, 2, 1)
                coodinates = coodinates.permute(0, 2, 1)

                # feed into model
                offset = self.model(curve)
                new_coodinates = offset + coodinates

                dist1, dist2, _, _ = self.criterion(
                    GT_points, new_coodinates.permute(0, 2, 1))
                loss = torch.mean(dist1) / self.num_points + torch.mean(
                    dist2) / self.gt_num_points

                self.val_loss.add(loss.item())
                self.vis.plot('val_loss',
                              self.val_loss.value()[0],
                              i + epoch * len(self.val_data))

            self.model.train()
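
The loss in both loops is a symmetric Chamfer distance computed by the chamfer_2DDist CUDA extension. For reference, a minimal pure-PyTorch sketch of the same quantity, assuming point sets of shape (B, N, 2) and (B, M, 2); it mirrors the (dist1, dist2, idx1, idx2) return interface used above, but it is an illustration, not the extension itself.

# Hypothetical pure-PyTorch reference for the Chamfer criterion used above.
import torch


def chamfer_2d(a, b):
    # a: (B, N, 2), b: (B, M, 2); returns squared nearest-neighbour distances
    # in both directions plus the matching indices.
    d = torch.cdist(a, b).pow(2)   # (B, N, M) pairwise squared distances
    dist1, idx1 = d.min(dim=2)     # for each point in a, its closest point in b
    dist2, idx2 = d.min(dim=1)     # for each point in b, its closest point in a
    return dist1, dist2, idx1, idx2


# The trainer's reduction would then be:
# loss = dist1.mean() / num_points + dist2.mean() / gt_num_points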