コード例 #1
0
ファイル: train.py プロジェクト: nabeelhthussain/ECCO
def evaluation():
    """Evaluate the saved model on the validation set and pickle its metrics.

    Loads the whole pickled module stored at ``model_name + '.pth'``, puts it
    in eval mode, runs ``evaluate`` over the validation pickle at
    ``val_path`` without gradient tracking, and writes the returned metrics
    object to ``results/<model_name>_predictions.pickle``.

    Relies on module-level ``args``, ``device``, ``model_name``, ``val_path``
    and the helpers ``read_pkl_data`` / ``evaluate``.
    """
    import os

    # NOTE(review): `am` is not used in this function; the constructor call is
    # kept only in case it has side effects something else depends on.
    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path,
                                batch_size=args.val_batch_size,
                                shuffle=False,
                                repeat=False)

    # map_location lets a checkpoint saved on another device load on `device`.
    # NOTE(review): this unpickles a full nn.Module, not a state_dict; newer
    # PyTorch releases default torch.load to weights_only=True, which will
    # likely require passing weights_only=False here — confirm for your
    # torch version.
    trained_model = torch.load(model_name + '.pth', map_location=device)
    trained_model.eval()

    with torch.no_grad():
        valid_total_loss, valid_metrics = evaluate(
            trained_model,
            val_dataset,
            train_window=args.train_window,
            max_iter=len(val_dataset),
            device=device,
            start_iter=args.val_batches,
            use_lane=args.use_lane,
            batch_size=args.val_batch_size)

    out_path = 'results/{}_predictions.pickle'.format(model_name)
    # Make sure the output directory exists so a (possibly long) evaluation
    # does not fail at the very last step.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(valid_metrics, f)
コード例 #2
0
def main():
    """Train the model from ``create_model()`` with PyTorch; return the status.

    Reads the module-level ``val_files``/``train_files`` and trains on CUDA
    with Adam and a piecewise-constant learning-rate factor until
    ``trainer.keep_training`` ends the loop. Resumes from the latest
    checkpoint managed by ``MyCheckpointManager`` if one exists. Every 10
    steps the loss and learning rate are logged; every ``_k`` steps the
    model is evaluated on the validation set (``frame_skip=20``). The final
    weights are always written to ``model_weights.pt``.

    Returns:
        ``trainer.STATUS_TRAINING_FINISHED`` if ``trainer.current_step``
        equals ``train_params.max_iter``, otherwise
        ``trainer.STATUS_TRAINING_UNFINISHED``.
    """

    device = torch.device('cuda')

    # Validation reads 1-frame windows (cache_data=True); training reads
    # 3-frame windows, matching the pos0/pos1/pos2 keys used below, with
    # random rotation enabled.
    val_dataset = read_data_val(files=val_files, window=1, cache_data=True)

    dataset = read_data_train(files=train_files,
                              batch_size=train_params.batch_size,
                              window=3,
                              random_rotation=True,
                              num_workers=2)
    data_iter = iter(dataset)

    trainer = Trainer(train_dir)

    model = create_model()
    model.to(device)

    # Steps (multiples of `_k`, presumably 1000 — confirm) after which the
    # learning-rate factor advances to the next entry of `lr_values`.
    boundaries = [
        25 * _k,
        30 * _k,
        35 * _k,
        40 * _k,
        45 * _k,
    ]
    # Multiplicative factors applied to train_params.base_lr by LambdaLR;
    # one more entry than `boundaries`.
    lr_values = [
        1.0,
        0.5,
        0.25,
        0.125,
        0.5 * 0.125,
        0.25 * 0.125,
    ]

    def lrfactor_fn(x):
        """Return the learning-rate factor for scheduler step ``x``.

        Starts at ``lr_values[0]`` and moves to the next value each time
        ``x`` is strictly greater than the next boundary.
        """
        factor = lr_values[0]
        for b, v in zip(boundaries, lr_values[1:]):
            if x > b:
                factor = v
            else:
                break
        return factor

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_params.base_lr,
                                 eps=1e-6)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lrfactor_fn)

    # Global step counter. It is never incremented in this function;
    # presumably trainer.keep_training advances it in place — confirm.
    # checkpoint_fn reads `step` through the closure, so it also sees the
    # tensor rebound when a checkpoint is restored below.
    step = torch.tensor(0)
    checkpoint_fn = lambda: {
        'step': step,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }

    # keep_checkpoint_steps: every `_k` steps up to and including max_iter.
    manager = MyCheckpointManager(checkpoint_fn,
                                  trainer.checkpoint_dir,
                                  keep_checkpoint_steps=list(
                                      range(1 * _k, train_params.max_iter + 1,
                                            1 * _k)))

    def euclidean_distance(a, b, epsilon=1e-9):
        """Distance over the last axis; ``epsilon`` keeps the sqrt gradient finite at 0."""
        return torch.sqrt(torch.sum((a - b)**2, dim=-1) + epsilon)

    def loss_fn(pr_pos, gt_pos, num_fluid_neighbors):
        """Return the mean of ``importance * distance**0.5``.

        ``importance = exp(-num_fluid_neighbors / 40)``, so entries with more
        neighbors contribute less to the loss.
        """
        gamma = 0.5
        neighbor_scale = 1 / 40
        importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
        return torch.mean(importance *
                          euclidean_distance(pr_pos, gt_pos)**gamma)

    def train(model, batch):
        """Take one optimizer step on ``batch`` and return the loss tensor.

        Each sample is rolled out two steps: frame 1 is predicted from
        frame 0, then frame 2 from that prediction; each step's loss is
        weighted 0.5.
        """
        optimizer.zero_grad()
        losses = []

        batch_size = train_params.batch_size
        for batch_i in range(batch_size):
            inputs = ([
                batch['pos0'][batch_i], batch['vel0'][batch_i], None,
                batch['box'][batch_i], batch['box_normals'][batch_i]
            ])

            pr_pos1, pr_vel1 = model(inputs)

            # model.num_fluid_neighbors is read right after the forward call;
            # presumably that call sets it — confirm.
            l = 0.5 * loss_fn(pr_pos1, batch['pos1'][batch_i],
                              model.num_fluid_neighbors)

            # The second step is fed the model's own first-step prediction.
            inputs = (pr_pos1, pr_vel1, None, batch['box'][batch_i],
                      batch['box_normals'][batch_i])
            pr_pos2, pr_vel2 = model(inputs)

            l += 0.5 * loss_fn(pr_pos2, batch['pos2'][batch_i],
                               model.num_fluid_neighbors)
            losses.append(l)

        # Batch mean scaled by a constant 128 (rationale not given here).
        total_loss = 128 * sum(losses) / batch_size
        total_loss.backward()
        optimizer.step()

        return total_loss

    # Resume from the most recent checkpoint, if any.
    if manager.latest_checkpoint:
        print('restoring from ', manager.latest_checkpoint)
        latest_checkpoint = torch.load(manager.latest_checkpoint)
        step = latest_checkpoint['step']
        model.load_state_dict(latest_checkpoint['model'])
        optimizer.load_state_dict(latest_checkpoint['optimizer'])
        scheduler.load_state_dict(latest_checkpoint['scheduler'])

    display_str_list = []
    while trainer.keep_training(step,
                                train_params.max_iter,
                                checkpoint_manager=manager,
                                display_str_list=display_str_list):

        data_fetch_start = time.time()
        batch = next(data_iter)
        # Convert each per-sample numpy array to a tensor on the device.
        batch_torch = {}
        for k in ('pos0', 'vel0', 'pos1', 'pos2', 'box', 'box_normals'):
            batch_torch[k] = [torch.from_numpy(x).to(device) for x in batch[k]]
        data_fetch_latency = time.time() - data_fetch_start
        trainer.log_scalar_every_n_minutes(5, 'DataLatency',
                                           data_fetch_latency)

        current_loss = train(model, batch_torch)
        # Stepped once per iteration, so lrfactor_fn sees the iteration count.
        scheduler.step()
        display_str_list = ['loss', float(current_loss)]

        if trainer.current_step % 10 == 0:
            trainer.summary_writer.add_scalar('TotalLoss', current_loss,
                                              trainer.current_step)
            trainer.summary_writer.add_scalar('LearningRate',
                                              scheduler.get_last_lr()[0],
                                              trainer.current_step)

        if trainer.current_step % (1 * _k) == 0:
            for k, v in evaluate(model,
                                 val_dataset,
                                 frame_skip=20,
                                 device=device).items():
                trainer.summary_writer.add_scalar('eval/' + k, v,
                                                  trainer.current_step)

    torch.save({'model': model.state_dict()}, 'model_weights.pt')
    if trainer.current_step == train_params.max_iter:
        return trainer.STATUS_TRAINING_FINISHED
    else:
        return trainer.STATUS_TRAINING_UNFINISHED
コード例 #3
0
def main():
    """Train the model from ``create_model()`` with TensorFlow; return the status.

    Reads the module-level ``val_files``/``train_files`` and trains with Adam
    under a ``PiecewiseConstantDecay`` learning-rate schedule until
    ``trainer.keep_training`` ends the loop. Resumes from the latest
    checkpoint managed by ``MyCheckpointManager`` if one exists. Every 10
    steps the loss and learning rate are written as summaries; every ``_k``
    steps the model is evaluated on the validation set (``frame_skip=20``).
    The final weights are always written to ``model_weights.h5``.

    Returns:
        ``trainer.STATUS_TRAINING_FINISHED`` if ``trainer.current_step``
        equals ``train_params.max_iter``, otherwise
        ``trainer.STATUS_TRAINING_UNFINISHED``.
    """

    # Validation reads 1-frame windows (cache_data=True); training reads
    # 3-frame windows, matching the pos0/pos1/pos2 keys used below, with
    # random rotation enabled.
    val_dataset = read_data_val(files=val_files, window=1, cache_data=True)

    dataset = read_data_train(files=train_files,
                              batch_size=train_params.batch_size,
                              window=3,
                              random_rotation=True,
                              num_workers=2)
    data_iter = iter(dataset)

    trainer = Trainer(train_dir)

    model = create_model()

    # Steps (multiples of `_k`, presumably 1000 — confirm) at which the
    # learning rate drops to the next entry of `lr_values`.
    boundaries = [
        25 * _k,
        30 * _k,
        35 * _k,
        40 * _k,
        45 * _k,
    ]
    # Absolute learning rates; one more entry than `boundaries`.
    # PiecewiseConstantDecay uses lr_values[i] while step <= boundaries[i].
    lr_values = [
        train_params.base_lr * 1.0,
        train_params.base_lr * 0.5,
        train_params.base_lr * 0.25,
        train_params.base_lr * 0.125,
        train_params.base_lr * 0.5 * 0.125,
        train_params.base_lr * 0.25 * 0.125,
    ]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, lr_values)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn,
                                         epsilon=1e-6)

    # checkpoint.step is never incremented in this function; presumably
    # trainer.keep_training advances it — confirm.
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     model=model,
                                     optimizer=optimizer)

    # keep_checkpoint_steps: every `_k` steps up to and including max_iter.
    manager = MyCheckpointManager(checkpoint,
                                  trainer.checkpoint_dir,
                                  keep_checkpoint_steps=list(
                                      range(1 * _k, train_params.max_iter + 1,
                                            1 * _k)))

    def euclidean_distance(a, b, epsilon=1e-9):
        """Distance over the last axis; ``epsilon`` keeps the sqrt gradient finite at 0."""
        return tf.sqrt(tf.reduce_sum((a - b)**2, axis=-1) + epsilon)

    def loss_fn(pr_pos, gt_pos, num_fluid_neighbors):
        """Return the mean of ``importance * distance**0.5``.

        ``importance = exp(-num_fluid_neighbors / 40)``, so entries with more
        neighbors contribute less to the loss.
        """
        gamma = 0.5
        neighbor_scale = 1 / 40
        importance = tf.exp(-neighbor_scale * num_fluid_neighbors)
        return tf.reduce_mean(importance *
                              euclidean_distance(pr_pos, gt_pos)**gamma)

    # experimental_relax_shapes=True lets tf.function reuse traces across
    # inputs whose shapes differ between calls instead of retracing.
    @tf.function(experimental_relax_shapes=True)
    def train(model, batch):
        """Take one optimizer step on ``batch`` and return the loss tensor.

        Each sample is rolled out two steps: frame 1 is predicted from
        frame 0, then frame 2 from that prediction; each step's loss is
        weighted 0.5.
        """
        with tf.GradientTape() as tape:
            losses = []

            batch_size = train_params.batch_size
            for batch_i in range(batch_size):
                inputs = ([
                    batch['pos0'][batch_i], batch['vel0'][batch_i], None,
                    batch['box'][batch_i], batch['box_normals'][batch_i]
                ])

                pr_pos1, pr_vel1 = model(inputs)

                # model.num_fluid_neighbors is read right after the forward
                # call; presumably that call sets it — confirm.
                l = 0.5 * loss_fn(pr_pos1, batch['pos1'][batch_i],
                                  model.num_fluid_neighbors)

                # The second step is fed the model's own first-step prediction.
                inputs = (pr_pos1, pr_vel1, None, batch['box'][batch_i],
                          batch['box_normals'][batch_i])
                pr_pos2, pr_vel2 = model(inputs)

                l += 0.5 * loss_fn(pr_pos2, batch['pos2'][batch_i],
                                   model.num_fluid_neighbors)
                losses.append(l)

            # Add losses registered by the model's layers (Keras
            # `model.losses`); they are scaled by 128 / batch_size together
            # with the data loss.
            losses.extend(model.losses)
            total_loss = 128 * tf.add_n(losses) / batch_size

            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total_loss

    # Resume from the most recent checkpoint, if any.
    if manager.latest_checkpoint:
        print('restoring from ', manager.latest_checkpoint)
        checkpoint.restore(manager.latest_checkpoint)

    display_str_list = []
    while trainer.keep_training(checkpoint.step,
                                train_params.max_iter,
                                checkpoint_manager=manager,
                                display_str_list=display_str_list):

        data_fetch_start = time.time()
        batch = next(data_iter)
        # Convert each per-sample array to a tf.Tensor.
        batch_tf = {}
        for k in ('pos0', 'vel0', 'pos1', 'pos2', 'box', 'box_normals'):
            batch_tf[k] = [tf.convert_to_tensor(x) for x in batch[k]]
        data_fetch_latency = time.time() - data_fetch_start
        trainer.log_scalar_every_n_minutes(5, 'DataLatency', data_fetch_latency)

        current_loss = train(model, batch_tf)
        display_str_list = ['loss', float(current_loss)]

        # No `step=` is passed to tf.summary.scalar, so these rely on a
        # default summary step being set elsewhere (presumably by Trainer)
        # — confirm.
        if trainer.current_step % 10 == 0:
            with trainer.summary_writer.as_default():
                tf.summary.scalar('TotalLoss', current_loss)
                tf.summary.scalar('LearningRate',
                                  optimizer.lr(trainer.current_step))

        if trainer.current_step % (1 * _k) == 0:
            for k, v in evaluate(model, val_dataset, frame_skip=20).items():
                with trainer.summary_writer.as_default():
                    tf.summary.scalar('eval/' + k, v)

    model.save_weights('model_weights.h5')
    if trainer.current_step == train_params.max_iter:
        return trainer.STATUS_TRAINING_FINISHED
    else:
        return trainer.STATUS_TRAINING_UNFINISHED
コード例 #4
0
def main():
    """Train a TensorFlow model configured by a YAML file; return the status.

    The single positional CLI argument is the path to a YAML config. The
    config must provide ``dataset_dir`` (containing ``train/*.zst`` and
    ``valid/*.zst``) and may provide ``train_data``, ``model`` and
    ``evaluation`` mappings, forwarded as keyword arguments to
    ``read_data_train``, ``create_model`` and ``evaluate`` respectively.
    Training uses Adam under a ``PiecewiseConstantDecay`` schedule until
    ``trainer.keep_training`` ends the loop, resuming from the latest
    checkpoint if one exists. The final weights are always written to
    ``model_weights.h5``.

    Returns:
        ``trainer.STATUS_TRAINING_FINISHED`` if ``trainer.current_step``
        equals ``train_params.max_iter``, otherwise
        ``trainer.STATUS_TRAINING_UNFINISHED``.
    """
    parser = argparse.ArgumentParser(description="Training script")
    parser.add_argument("cfg",
                        type=str,
                        help="The path to the yaml config file")
    # With no CLI arguments, print usage to stderr and exit with status 1.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        cfg = yaml.safe_load(f)

    # the train dir stores all checkpoints and summaries. The dir name is the name of this file combined with the name of the config file
    train_dir = os.path.splitext(
        os.path.basename(__file__))[0] + '_' + os.path.splitext(
            os.path.basename(args.cfg))[0]

    val_files = sorted(glob(os.path.join(cfg['dataset_dir'], 'valid',
                                         '*.zst')))
    train_files = sorted(
        glob(os.path.join(cfg['dataset_dir'], 'train', '*.zst')))

    # Validation reads 1-frame windows (cache_data=True); training reads
    # 3-frame windows, matching the pos0/pos1/pos2 keys used below.
    val_dataset = read_data_val(files=val_files, window=1, cache_data=True)

    dataset = read_data_train(files=train_files,
                              batch_size=train_params.batch_size,
                              window=3,
                              num_workers=2,
                              **cfg.get('train_data', {}))
    data_iter = iter(dataset)

    trainer = Trainer(train_dir)

    model = create_model(**cfg.get('model', {}))

    # Steps (multiples of `_k`, presumably 1000 — confirm) at which the
    # learning rate drops to the next entry of `lr_values`.
    boundaries = [
        25 * _k,
        30 * _k,
        35 * _k,
        40 * _k,
        45 * _k,
    ]
    # Absolute learning rates; one more entry than `boundaries`.
    # PiecewiseConstantDecay uses lr_values[i] while step <= boundaries[i].
    lr_values = [
        train_params.base_lr * 1.0,
        train_params.base_lr * 0.5,
        train_params.base_lr * 0.25,
        train_params.base_lr * 0.125,
        train_params.base_lr * 0.5 * 0.125,
        train_params.base_lr * 0.25 * 0.125,
    ]
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, lr_values)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn,
                                         epsilon=1e-6)

    # checkpoint.step is never incremented in this function; presumably
    # trainer.keep_training advances it — confirm.
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     model=model,
                                     optimizer=optimizer)

    # keep_checkpoint_steps: every `_k` steps up to and including max_iter.
    manager = MyCheckpointManager(checkpoint,
                                  trainer.checkpoint_dir,
                                  keep_checkpoint_steps=list(
                                      range(1 * _k, train_params.max_iter + 1,
                                            1 * _k)))

    def euclidean_distance(a, b, epsilon=1e-9):
        """Distance over the last axis; ``epsilon`` keeps the sqrt gradient finite at 0."""
        return tf.sqrt(tf.reduce_sum((a - b)**2, axis=-1) + epsilon)

    def loss_fn(pr_pos, gt_pos, num_fluid_neighbors):
        """Return the mean of ``importance * distance**0.5``.

        ``importance = exp(-num_fluid_neighbors / 40)``, so entries with more
        neighbors contribute less to the loss.
        """
        gamma = 0.5
        neighbor_scale = 1 / 40
        importance = tf.exp(-neighbor_scale * num_fluid_neighbors)
        return tf.reduce_mean(importance *
                              euclidean_distance(pr_pos, gt_pos)**gamma)

    # experimental_relax_shapes=True lets tf.function reuse traces across
    # inputs whose shapes differ between calls instead of retracing.
    @tf.function(experimental_relax_shapes=True)
    def train(model, batch):
        """Take one optimizer step on ``batch`` and return the loss tensor.

        Each sample is rolled out two steps: frame 1 is predicted from
        frame 0, then frame 2 from that prediction; each step's loss is
        weighted 0.5.
        """
        with tf.GradientTape() as tape:
            losses = []

            batch_size = train_params.batch_size
            for batch_i in range(batch_size):
                inputs = ([
                    batch['pos0'][batch_i], batch['vel0'][batch_i], None,
                    batch['box'][batch_i], batch['box_normals'][batch_i]
                ])

                pr_pos1, pr_vel1 = model(inputs)

                # model.num_fluid_neighbors is read right after the forward
                # call; presumably that call sets it — confirm.
                l = 0.5 * loss_fn(pr_pos1, batch['pos1'][batch_i],
                                  model.num_fluid_neighbors)

                # The second step is fed the model's own first-step prediction.
                inputs = (pr_pos1, pr_vel1, None, batch['box'][batch_i],
                          batch['box_normals'][batch_i])
                pr_pos2, pr_vel2 = model(inputs)

                l += 0.5 * loss_fn(pr_pos2, batch['pos2'][batch_i],
                                   model.num_fluid_neighbors)
                losses.append(l)

            # Add losses registered by the model's layers (Keras
            # `model.losses`); they are scaled by 128 / batch_size together
            # with the data loss.
            losses.extend(model.losses)
            total_loss = 128 * tf.add_n(losses) / batch_size

            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return total_loss

    # Resume from the most recent checkpoint, if any.
    if manager.latest_checkpoint:
        print('restoring from ', manager.latest_checkpoint)
        checkpoint.restore(manager.latest_checkpoint)

    display_str_list = []
    while trainer.keep_training(checkpoint.step,
                                train_params.max_iter,
                                checkpoint_manager=manager,
                                display_str_list=display_str_list):

        data_fetch_start = time.time()
        batch = next(data_iter)
        # Convert each per-sample array to a tf.Tensor.
        batch_tf = {}
        for k in ('pos0', 'vel0', 'pos1', 'pos2', 'box', 'box_normals'):
            batch_tf[k] = [tf.convert_to_tensor(x) for x in batch[k]]
        data_fetch_latency = time.time() - data_fetch_start
        trainer.log_scalar_every_n_minutes(5, 'DataLatency',
                                           data_fetch_latency)

        current_loss = train(model, batch_tf)
        display_str_list = ['loss', float(current_loss)]

        # No `step=` is passed to tf.summary.scalar, so these rely on a
        # default summary step being set elsewhere (presumably by Trainer)
        # — confirm.
        if trainer.current_step % 10 == 0:
            with trainer.summary_writer.as_default():
                tf.summary.scalar('TotalLoss', current_loss)
                tf.summary.scalar('LearningRate',
                                  optimizer.lr(trainer.current_step))

        if trainer.current_step % (1 * _k) == 0:
            for k, v in evaluate(model,
                                 val_dataset,
                                 frame_skip=20,
                                 **cfg.get('evaluation', {})).items():
                with trainer.summary_writer.as_default():
                    tf.summary.scalar('eval/' + k, v)

    model.save_weights('model_weights.h5')
    if trainer.current_step == train_params.max_iter:
        return trainer.STATUS_TRAINING_FINISHED
    else:
        return trainer.STATUS_TRAINING_UNFINISHED
コード例 #5
0
ファイル: train.py プロジェクト: nabeelhthussain/ECCO
def train():
    """Train the model with gradient accumulation and keep the best weights.

    Trains for ``args.epochs`` epochs of ``args.batches_per_epoch`` optimizer
    steps each. Every optimizer step accumulates gradients over
    ``args.batch_divide`` loader sub-batches of
    ``args.batch_size // args.batch_divide`` samples. After each epoch the
    model is evaluated on up to ``args.val_batches`` validation batches.
    The unwrapped module is saved to ``model_name + ".pth"`` whenever the
    validation loss reaches a new minimum. The first finite validation loss
    counts as a minimum, so the weights are saved even if epoch 1 is the
    best. The learning rate is multiplied by 0.95 at the end of every epoch.

    Relies on module-level ``args``, ``device``, ``model_name``,
    ``train_path``, ``val_path`` and the helpers ``read_pkl_data``,
    ``create_model``, ``MyDataParallel``, ``loss_fn``, ``evaluate``,
    ``clean_cache`` and ``get_lr``.
    """
    import math

    # NOTE(review): `am` is not used in this function; the constructor call is
    # kept only in case it has side effects something else depends on.
    am = ArgoverseMap()

    val_dataset = read_pkl_data(val_path,
                                batch_size=args.val_batch_size,
                                shuffle=True,
                                repeat=False,
                                max_lane_nodes=700)

    # Each loader batch is one sub-batch; `args.batch_divide` of them make up
    # one optimizer step of `args.batch_size` samples.
    dataset = read_pkl_data(train_path,
                            batch_size=args.batch_size // args.batch_divide,
                            repeat=True,
                            shuffle=True,
                            max_lane_nodes=900)

    data_iter = iter(dataset)

    model = create_model().to(device)
    model = MyDataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.base_lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=4e-4)
    # Multiply the learning rate by 0.95 each time scheduler.step() runs
    # (once per epoch, at the end of the epoch loop).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=1,
                                                gamma=0.95)

    def train_one_batch(model, batch, train_window=2):
        """Roll the model out ``train_window`` steps; return the scaled loss.

        The first step is fed the observed history (``pos_2s``/``vel_2s``)
        and the current state. Each later step is fed the previous
        prediction together with the recurrent ``states``, and is scored
        against ``pos{k}``. Every step's loss is weighted 0.5.

        The sum is normalized by the full ``args.batch_size``, not by the
        sub-batch size. This way, gradients accumulated over
        ``args.batch_divide`` sub-batches average over the whole batch.
        """
        batch_size = args.batch_size

        inputs = ([
            batch['pos_2s'], batch['vel_2s'], batch['pos0'], batch['vel0'],
            batch['accel'], None, batch['lane'], batch['lane_norm'],
            batch['car_mask'], batch['lane_mask']
        ])

        pr_pos1, pr_vel1, states = model(inputs)
        gt_pos1 = batch['pos1']

        losses = 0.5 * loss_fn(pr_pos1, gt_pos1,
                               torch.sum(batch['car_mask'], dim=-2) - 1,
                               batch['car_mask'].squeeze(-1))
        del gt_pos1

        pos0 = batch['pos0']
        vel0 = batch['vel0']
        for step_i in range(train_window - 1):
            pos_enc = torch.unsqueeze(pos0, 2)
            vel_enc = torch.unsqueeze(vel0, 2)
            inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, batch['accel'], None,
                      batch['lane'], batch['lane_norm'], batch['car_mask'],
                      batch['lane_mask'])
            pos0, vel0 = pr_pos1, pr_vel1

            pr_pos1, pr_vel1, states = model(inputs, states)
            gt_pos1 = batch['pos' + str(step_i + 2)]

            losses += 0.5 * loss_fn(pr_pos1, gt_pos1,
                                    torch.sum(batch['car_mask'], dim=-2) - 1,
                                    batch['car_mask'].squeeze(-1))

        total_loss = 128 * torch.sum(losses, axis=0) / batch_size

        return total_loss

    epochs = args.epochs
    # Optimizer steps per epoch; the dataset is too large to run in full.
    batches_per_epoch = args.batches_per_epoch
    # Loader sub-batches actually consumed per epoch.
    sub_batches_per_epoch = batches_per_epoch * args.batch_divide
    data_load_times = []  # Per sub-batch
    train_losses = []
    valid_losses = []
    valid_metrics_list = []
    min_loss = None

    # Key lists are loop-invariant, so build them once.
    convert_keys = (
        ['pos' + str(t) for t in range(args.train_window + 1)] +
        ['vel' + str(t) for t in range(args.train_window + 1)] +
        ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
    passthrough_keys = ['track_id' + str(t) for t in range(31)] + ['city']

    for i in range(epochs):
        epoch_start_time = time.time()

        model.train()
        epoch_train_loss = 0
        sub_idx = 0

        print("training ... epoch " + str(i + 1), end='')
        for batch_itr in range(sub_batches_per_epoch):

            data_fetch_start = time.time()
            batch = next(data_iter)

            if sub_idx == 0:
                # First sub-batch of a new optimizer step.
                optimizer.zero_grad()
                if (batch_itr // args.batch_divide) % 25 == 0:
                    print("... batch " +
                          str((batch_itr // args.batch_divide) + 1),
                          end='',
                          flush=True)
            sub_idx += 1

            batch_size = len(batch['pos0'])

            batch_tensor = {}
            for k in convert_keys:
                batch_tensor[k] = torch.tensor(np.stack(batch[k]),
                                               dtype=torch.float32,
                                               device=device)

            # car_mask keeps its stacked shape; lane_mask gains a trailing
            # singleton dimension.
            batch_tensor['car_mask'] = torch.tensor(np.stack(batch['car_mask']),
                                                    dtype=torch.float32,
                                                    device=device)
            batch_tensor['lane_mask'] = torch.tensor(
                np.stack(batch['lane_mask']),
                dtype=torch.float32,
                device=device).unsqueeze(-1)

            for k in passthrough_keys:
                batch_tensor[k] = batch[k]

            # Zero acceleration input of shape (batch_size, 1, 2).
            batch_tensor['accel'] = torch.zeros(batch_size, 1, 2).to(device)
            del batch

            data_fetch_latency = time.time() - data_fetch_start
            data_load_times.append(data_fetch_latency)

            current_loss = train_one_batch(model,
                                           batch_tensor,
                                           train_window=args.train_window)

            if sub_idx < args.batch_divide:
                # Accumulate gradients; the optimizer steps after the last
                # sub-batch. NOTE(review): retain_graph=True looks unneeded
                # because each sub-batch builds its own graph. It is kept in
                # case the model reuses graph-attached state across calls —
                # confirm.
                current_loss.backward(retain_graph=True)
            else:
                current_loss.backward()
                optimizer.step()
                sub_idx = 0
            del batch_tensor

            epoch_train_loss += float(current_loss)
            del current_loss
            clean_cache(device)

            # Compare against the number of sub-batches actually iterated so
            # "DONE" is printed at the end of the epoch for any batch_divide.
            if batch_itr == sub_batches_per_epoch - 1:
                print("... DONE", flush=True)

        train_losses.append(epoch_train_loss)

        model.eval()
        with torch.no_grad():
            valid_total_loss, valid_metrics = evaluate(
                model.module,
                val_dataset,
                train_window=args.train_window,
                max_iter=args.val_batches,
                device=device,
                batch_size=args.val_batch_size)

        valid_losses.append(float(valid_total_loss))
        valid_metrics_list.append(valid_metrics)

        # Save on every new minimum. The first finite loss counts, so a model
        # that is best after epoch 1 is still written. A NaN loss never counts,
        # and it never becomes the reference that blocks later saves.
        current_val_loss = valid_losses[-1]
        if not math.isnan(current_val_loss) and (
                min_loss is None or current_val_loss < min_loss):
            print('update weights')
            min_loss = current_val_loss
            torch.save(model.module, model_name + ".pth")

        epoch_end_time = time.time()

        print(
            'epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'
            .format(i + 1, train_losses[-1], valid_losses[-1],
                    round((epoch_end_time - epoch_start_time) / 60, 5),
                    format(get_lr(optimizer), "5.2e"), model_name))

        scheduler.step()
コード例 #6
0
def main():
    """Train a PyTorch model configured by a YAML file; return the status.

    The single positional CLI argument is the path to a YAML config. The
    config must provide ``dataset_dir`` (containing ``train/*.zst`` and
    ``valid/*.zst``) and may provide ``train_data``, ``model`` and
    ``evaluation`` mappings, forwarded as keyword arguments to
    ``read_data_train``, ``create_model`` and ``evaluate`` respectively.
    Training runs on CUDA with Adam and a piecewise-constant learning-rate
    factor until ``trainer.keep_training`` ends the loop, resuming from the
    latest checkpoint if one exists. The final weights are always written
    to ``model_weights.pt``.

    Returns:
        ``trainer.STATUS_TRAINING_FINISHED`` if ``trainer.current_step``
        equals ``train_params.max_iter``, otherwise
        ``trainer.STATUS_TRAINING_UNFINISHED``.
    """
    parser = argparse.ArgumentParser(description="Training script")
    parser.add_argument("cfg",
                        type=str,
                        help="The path to the yaml config file")
    # With no CLI arguments, print usage to stderr and exit with status 1.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()

    with open(args.cfg, 'r') as f:
        cfg = yaml.safe_load(f)

    # the train dir stores all checkpoints and summaries. The dir name is the name of this file combined with the name of the config file
    train_dir = os.path.splitext(
        os.path.basename(__file__))[0] + '_' + os.path.splitext(
            os.path.basename(args.cfg))[0]

    val_files = sorted(glob(os.path.join(cfg['dataset_dir'], 'valid',
                                         '*.zst')))
    train_files = sorted(
        glob(os.path.join(cfg['dataset_dir'], 'train', '*.zst')))

    device = torch.device('cuda')

    # Validation reads 1-frame windows (cache_data=True); training reads
    # 3-frame windows, matching the pos0/pos1/pos2 keys used below.
    val_dataset = read_data_val(files=val_files, window=1, cache_data=True)

    dataset = read_data_train(files=train_files,
                              batch_size=train_params.batch_size,
                              window=3,
                              num_workers=2,
                              **cfg.get('train_data', {}))
    data_iter = iter(dataset)

    trainer = Trainer(train_dir)

    model = create_model(**cfg.get('model', {}))
    model.to(device)

    # Steps (multiples of `_k`, presumably 1000 — confirm) after which the
    # learning-rate factor advances to the next entry of `lr_values`.
    boundaries = [
        25 * _k,
        30 * _k,
        35 * _k,
        40 * _k,
        45 * _k,
    ]
    # Multiplicative factors applied to train_params.base_lr by LambdaLR;
    # one more entry than `boundaries`.
    lr_values = [
        1.0,
        0.5,
        0.25,
        0.125,
        0.5 * 0.125,
        0.25 * 0.125,
    ]

    def lrfactor_fn(x):
        """Return the learning-rate factor for scheduler step ``x``.

        Starts at ``lr_values[0]`` and moves to the next value each time
        ``x`` is strictly greater than the next boundary.
        """
        factor = lr_values[0]
        for b, v in zip(boundaries, lr_values[1:]):
            if x > b:
                factor = v
            else:
                break
        return factor

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=train_params.base_lr,
                                 eps=1e-6)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lrfactor_fn)

    # Global step counter. It is never incremented in this function;
    # presumably trainer.keep_training advances it in place — confirm.
    # checkpoint_fn reads `step` through the closure, so it also sees the
    # tensor rebound when a checkpoint is restored below.
    step = torch.tensor(0)
    checkpoint_fn = lambda: {
        'step': step,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }

    # keep_checkpoint_steps: every `_k` steps up to and including max_iter.
    manager = MyCheckpointManager(checkpoint_fn,
                                  trainer.checkpoint_dir,
                                  keep_checkpoint_steps=list(
                                      range(1 * _k, train_params.max_iter + 1,
                                            1 * _k)))

    def euclidean_distance(a, b, epsilon=1e-9):
        """Distance over the last axis; ``epsilon`` keeps the sqrt gradient finite at 0."""
        return torch.sqrt(torch.sum((a - b)**2, dim=-1) + epsilon)

    def loss_fn(pr_pos, gt_pos, num_fluid_neighbors):
        """Return the mean of ``importance * distance**0.5``.

        ``importance = exp(-num_fluid_neighbors / 40)``, so entries with more
        neighbors contribute less to the loss.
        """
        gamma = 0.5
        neighbor_scale = 1 / 40
        importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
        return torch.mean(importance *
                          euclidean_distance(pr_pos, gt_pos)**gamma)

    def train(model, batch):
        """Take one optimizer step on ``batch`` and return the loss tensor.

        Each sample is rolled out two steps: frame 1 is predicted from
        frame 0, then frame 2 from that prediction; each step's loss is
        weighted 0.5.
        """
        optimizer.zero_grad()
        losses = []

        batch_size = train_params.batch_size
        for batch_i in range(batch_size):
            inputs = ([
                batch['pos0'][batch_i], batch['vel0'][batch_i], None,
                batch['box'][batch_i], batch['box_normals'][batch_i]
            ])

            pr_pos1, pr_vel1 = model(inputs)

            # model.num_fluid_neighbors is read right after the forward call;
            # presumably that call sets it — confirm.
            l = 0.5 * loss_fn(pr_pos1, batch['pos1'][batch_i],
                              model.num_fluid_neighbors)

            # The second step is fed the model's own first-step prediction.
            inputs = (pr_pos1, pr_vel1, None, batch['box'][batch_i],
                      batch['box_normals'][batch_i])
            pr_pos2, pr_vel2 = model(inputs)

            l += 0.5 * loss_fn(pr_pos2, batch['pos2'][batch_i],
                               model.num_fluid_neighbors)
            losses.append(l)

        # Batch mean scaled by a constant 128 (rationale not given here).
        total_loss = 128 * sum(losses) / batch_size
        total_loss.backward()
        optimizer.step()

        return total_loss

    # Resume from the most recent checkpoint, if any.
    if manager.latest_checkpoint:
        print('restoring from ', manager.latest_checkpoint)
        latest_checkpoint = torch.load(manager.latest_checkpoint)
        step = latest_checkpoint['step']
        model.load_state_dict(latest_checkpoint['model'])
        optimizer.load_state_dict(latest_checkpoint['optimizer'])
        scheduler.load_state_dict(latest_checkpoint['scheduler'])

    display_str_list = []
    while trainer.keep_training(step,
                                train_params.max_iter,
                                checkpoint_manager=manager,
                                display_str_list=display_str_list):

        data_fetch_start = time.time()
        batch = next(data_iter)
        # Convert each per-sample numpy array to a tensor on the device.
        batch_torch = {}
        for k in ('pos0', 'vel0', 'pos1', 'pos2', 'box', 'box_normals'):
            batch_torch[k] = [torch.from_numpy(x).to(device) for x in batch[k]]
        data_fetch_latency = time.time() - data_fetch_start
        trainer.log_scalar_every_n_minutes(5, 'DataLatency',
                                           data_fetch_latency)

        current_loss = train(model, batch_torch)
        # Stepped once per iteration, so lrfactor_fn sees the iteration count.
        scheduler.step()
        display_str_list = ['loss', float(current_loss)]

        if trainer.current_step % 10 == 0:
            trainer.summary_writer.add_scalar('TotalLoss', current_loss,
                                              trainer.current_step)
            trainer.summary_writer.add_scalar('LearningRate',
                                              scheduler.get_last_lr()[0],
                                              trainer.current_step)

        if trainer.current_step % (1 * _k) == 0:
            for k, v in evaluate(model,
                                 val_dataset,
                                 frame_skip=20,
                                 device=device,
                                 **cfg.get('evaluation', {})).items():
                trainer.summary_writer.add_scalar('eval/' + k, v,
                                                  trainer.current_step)

    torch.save({'model': model.state_dict()}, 'model_weights.pt')
    if trainer.current_step == train_params.max_iter:
        return trainer.STATUS_TRAINING_FINISHED
    else:
        return trainer.STATUS_TRAINING_UNFINISHED