Example #1
    # Nested progress/logging callback. Names such as `current_task`,
    # `total_iterations`, `batch_size`, the logging intervals, and the visdom
    # `env` are captured from the enclosing training function's scope.
    def cb(solver, progress, batch_index, result):
        iteration = (current_task - 1) * total_iterations + batch_index
        progress.set_description(
            ('<Training Solver>    '
             'task: {task}/{tasks} | '
             'progress: [{trained}/{total}] ({percentage:.0f}%) | '
             'loss: {loss:.4} | '
             'prec: {prec:.4}').format(
                 task=current_task,
                 tasks=total_tasks,
                 trained=batch_size * batch_index,
                 total=batch_size * total_iterations,
                 percentage=(100. * batch_index / total_iterations),
                 loss=result['loss'],
                 prec=result['precision'],
             ))

        # log the loss of the solver.
        if iteration % loss_log_interval == 0:
            visual.visualize_scalar(result['loss'],
                                    'solver loss',
                                    iteration,
                                    env=env)

        # evaluate the solver on multiple tasks.
        if iteration % eval_log_interval == 0:
            names = [
                'task {}'.format(i + 1) for i in range(len(test_datasets))
            ]
            precs = [
                utils.validate(solver,
                               test_datasets[i],
                               test_size=test_size,
                               cuda=cuda,
                               verbose=False,
                               collate_fn=collate_fn,
                               train=(valid_proportion <= 0),
                               valid_proportion=valid_proportion)
                if i + 1 <= current_task else 0
                for i in range(len(test_datasets))
            ]
            title = 'precision ({replay_mode})'.format(replay_mode=replay_mode)
            visual.visualize_scalars(precs, names, title, iteration, env=env)
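The callback above reads names such as current_task, total_iterations, batch_size, the logging intervals, and the visdom env without defining them, so it only works when it is built inside an outer training routine whose locals it closes over. A minimal sketch of that closure pattern follows; the factory name make_solver_callback and its parameter list are assumptions for illustration, not part of the example.

def make_solver_callback(current_task, total_tasks, total_iterations,
                         batch_size, test_datasets, test_size,
                         loss_log_interval=30, eval_log_interval=50,
                         replay_mode='none', valid_proportion=0.,
                         collate_fn=None, cuda=False, env='main'):
    # Returns a callback with the signature expected above; every outer
    # argument is visible to the inner function through the closure.
    def cb(solver, progress, batch_index, result):
        ...  # body as in Example #1
    return cb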
Example #2
from torch import optim
from torch.autograd import Variable
from tqdm import tqdm

import utils   # project-local data/checkpoint helpers referenced below
import visual  # project-local visdom wrappers referenced below


def train_model(model,
                dataset,
                epochs=10,
                batch_size=32,
                sample_size=32,
                lr=3e-04,
                weight_decay=1e-5,
                loss_log_interval=30,
                image_log_interval=300,
                checkpoint_dir='./checkpoints',
                resume=False,
                cuda=False):
    # prepare optimizer and model
    model.train()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )

    if resume:
        epoch_start = utils.load_checkpoint(model, checkpoint_dir)
    else:
        epoch_start = 1

    for epoch in range(epoch_start, epochs + 1):
        data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (x, _) in data_stream:
            # where are we?
            iteration = (epoch - 1) * (len(dataset) //
                                       batch_size) + batch_index

            # prepare data on gpu if needed
            x = Variable(x).cuda() if cuda else Variable(x)

            # flush gradients and run the model forward
            optimizer.zero_grad()
            (mean, logvar), x_reconstructed = model(x)
            reconstruction_loss = model.reconstruction_loss(x_reconstructed, x)
            kl_divergence_loss = model.kl_divergence_loss(mean, logvar)
            total_loss = reconstruction_loss + kl_divergence_loss

            # backprop gradients from the loss
            total_loss.backward()
            optimizer.step()

            # update progress
            data_stream.set_description(
                ('epoch: {epoch} | '
                 'iteration: {iteration} | '
                 'progress: [{trained}/{total}] ({progress:.0f}%) | '
                 'loss => '
                 'total: {total_loss:.4f} / '
                 're: {reconstruction_loss:.3f} / '
                 'kl: {kl_divergence_loss:.3f}').format(
                     epoch=epoch,
                     iteration=iteration,
                     trained=batch_index * len(x),
                     total=len(data_loader.dataset),
                     progress=(100. * batch_index / len(data_loader)),
                     total_loss=total_loss.data.item(),
                     reconstruction_loss=reconstruction_loss.data.item(),
                     kl_divergence_loss=kl_divergence_loss.data.item(),
                 ))

            if iteration % loss_log_interval == 0:
                losses = [
                    reconstruction_loss.data.item(),
                    kl_divergence_loss.data.item(),
                    total_loss.data.item(),
                ]
                names = ['reconstruction', 'kl divergence', 'total']
                visual.visualize_scalars(losses,
                                         names,
                                         'loss',
                                         iteration,
                                         env=model.name)

            if iteration % image_log_interval == 0:
                images = model.sample(sample_size)
                visual.visualize_images(images,
                                        'generated samples',
                                        env=model.name)

        # notify that we've reached a new checkpoint.
        print()
        print()
        print('#############')
        print('# checkpoint!')
        print('#############')
        print()

        # save the checkpoint.
        utils.save_checkpoint(model, checkpoint_dir, epoch)
        print()
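A minimal, hypothetical usage sketch for train_model: VAE is a placeholder class assumed to expose reconstruction_loss(), kl_divergence_loss(), sample(), and a name attribute, as the loop above requires; the MNIST dataset is likewise just an illustrative choice.

import torch
from torchvision import datasets, transforms

# Any dataset of (image, label) pairs works; only the images are used.
train_set = datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor())

vae = VAE()  # assumed model class, not defined in this example
train_model(vae, train_set,
            epochs=10,
            batch_size=32,
            checkpoint_dir='./checkpoints',
            cuda=torch.cuda.is_available())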
Example #3
from torch import nn, optim
from torch.autograd import Variable
from tqdm import tqdm

import utils   # project-local data/validation helpers referenced below
import visual  # project-local visdom wrappers referenced below


def train(model, train_datasets, test_datasets, epochs_per_task=10,
          batch_size=64, test_size=1024, consolidate=True,
          fisher_estimation_sample_size=1024,
          lr=1e-3, weight_decay=1e-5, lamda=3,
          loss_log_interval=30,
          eval_log_interval=50,
          cuda=False):
    # prepare the loss criterion and the optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)

    # set the model's mode to training mode.
    model.train()

    for task, train_dataset in enumerate(train_datasets, 1):
        for epoch in range(1, epochs_per_task+1):
            # prepare the data loaders.
            data_loader = utils.get_data_loader(
                train_dataset, batch_size=batch_size,
                cuda=cuda
            )
            data_stream = tqdm(enumerate(data_loader, 1))

            for batch_index, (x, y) in data_stream:
                # where are we?
                data_size = len(x)
                dataset_size = len(data_loader.dataset)
                dataset_batches = len(data_loader)
                previous_task_iteration = sum([
                    epochs_per_task * len(d) // batch_size for d in
                    train_datasets[:task-1]
                ])
                current_task_iteration = (
                    (epoch-1)*dataset_batches + batch_index
                )
                iteration = (
                    previous_task_iteration +
                    current_task_iteration
                )

                # prepare the data.
                x = x.view(data_size, -1)
                x = Variable(x).cuda() if cuda else Variable(x)
                y = Variable(y).cuda() if cuda else Variable(y)

                # run the model and backpropagate the errors.
                optimizer.zero_grad()
                scores = model(x)
                ce_loss = criterion(scores, y)
                ewc_loss = model.ewc_loss(lamda, cuda=cuda)
                loss = ce_loss + ewc_loss
                loss.backward()
                optimizer.step()

                # calculate the training precision.
                _, predicted = scores.max(1)
                precision = (predicted == y).sum().data[0] / len(x)

                data_stream.set_description((
                    'task: {task}/{tasks} | '
                    'epoch: {epoch}/{epochs} | '
                    'progress: [{trained}/{total}] ({progress:.0f}%) | '
                    'prec: {prec:.4} | '
                    'loss => '
                    'ce: {ce_loss:.4} / '
                    'ewc: {ewc_loss:.4} / '
                    'total: {loss:.4}'
                ).format(
                    task=task,
                    tasks=len(train_datasets),
                    epoch=epoch,
                    epochs=epochs_per_task,
                    trained=batch_index*batch_size,
                    total=dataset_size,
                    progress=(100.*batch_index/dataset_batches),
                    prec=precision,
                    ce_loss=ce_loss.data[0],
                    ewc_loss=ewc_loss.data[0],
                    loss=loss.data[0],
                ))

                # Send test precision to the visdom server.
                if iteration % eval_log_interval == 0:
                    names = [
                        'task {}'.format(i+1) for i in
                        range(len(train_datasets))
                    ]
                    precs = [
                        utils.validate(
                            model, test_datasets[i], test_size=test_size,
                            cuda=cuda, verbose=False,
                        ) if i+1 <= task else 0 for i in
                        range(len(train_datasets))
                    ]
                    title = (
                        'precision (consolidated)' if consolidate else
                        'precision'
                    )
                    visual.visualize_scalars(
                        precs, names, title,
                        iteration, env=model.name,
                    )

                # Send losses to the visdom server.
                if iteration % loss_log_interval == 0:
                    title = 'loss (consolidated)' if consolidate else 'loss'
                    visual.visualize_scalars(
                        [loss.data, ce_loss.data, ewc_loss.data],
                        ['total', 'cross entropy', 'ewc'],
                        title, iteration, env=model.name
                    )

        if consolidate:
            # estimate the fisher information of the parameters and consolidate
            # them in the network.
            model.consolidate(model.estimate_fisher(
                train_dataset, fisher_estimation_sample_size
            ))
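A hypothetical call-site sketch for the EWC training loop above: EWCClassifier is a placeholder assumed to implement ewc_loss(), estimate_fisher(), consolidate(), and a name attribute, and the per-task datasets are placeholders as well.

import torch

model = EWCClassifier()                      # assumed model class
train_tasks = [task1_train, task2_train]     # assumed per-task train datasets
test_tasks = [task1_test, task2_test]        # assumed per-task test datasets

train(model, train_tasks, test_tasks,
      epochs_per_task=10, batch_size=64,
      consolidate=True, lamda=3,
      cuda=torch.cuda.is_available())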
Example #4
from torch import nn, optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm

import splits  # project-local block-diagonalization helpers referenced below
import utils   # project-local data/validation helpers referenced below
import visual  # project-local visdom wrappers referenced below


def train(model,
          train_dataset,
          test_dataset=None,
          model_dir='models',
          lr=1e-04,
          lr_decay=.1,
          lr_decay_epochs=None,
          weight_decay=1e-04,
          gamma1=1.,
          gamma2=1.,
          gamma3=10.,
          batch_size=32,
          test_size=256,
          epochs=5,
          eval_log_interval=30,
          loss_log_interval=30,
          weight_log_interval=500,
          checkpoint_interval=500,
          resume_best=False,
          resume_latest=False,
          cuda=False):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)
    scheduler = MultiStepLR(optimizer, lr_decay_epochs, gamma=lr_decay)

    # prepare the model and statistics.
    model.train()
    epoch_start = 1
    best_precision = 0

    # load checkpoint if needed.
    if resume_latest or resume_best:
        epoch_start, best_precision = utils.load_checkpoint(model,
                                                            model_dir,
                                                            best=resume_best)

    for epoch in range(epoch_start, epochs + 1):
        # adjust learning rate if needed.
        scheduler.step(epoch - 1)

        # prepare a data stream for the epoch.
        data_loader = utils.get_data_loader(train_dataset,
                                            batch_size,
                                            cuda=cuda)
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (data, labels) in data_stream:
            # where are we?
            data_size = len(data)
            dataset_size = len(data_loader.dataset)
            dataset_batches = len(data_loader)
            iteration = ((epoch - 1) *
                         (len(data_loader.dataset) // batch_size) +
                         batch_index + 1)

            # clear the gradients.
            optimizer.zero_grad()

            # run the network.
            x = Variable(data).cuda() if cuda else Variable(data)
            labels = Variable(labels).cuda() if cuda else Variable(labels)
            scores = model(x)
            _, predicted = scores.max(1)
            precision = (labels == predicted).sum().data[0] / data_size

            # update the network.
            cross_entropy_loss = criterion(scores, labels)
            overlap_loss, uniform_loss, split_loss = model.reg_loss()
            overlap_loss *= gamma1
            uniform_loss *= gamma3
            split_loss *= gamma2
            reg_loss = overlap_loss + uniform_loss + split_loss

            total_loss = cross_entropy_loss + reg_loss
            total_loss.backward(retain_graph=True)
            optimizer.step()

            # update & display statistics.
            data_stream.set_description(
                ('epoch: {epoch}/{epochs} | '
                 'it: {iteration} | '
                 'progress: [{trained}/{total}] ({progress:.0f}%) | '
                 'prec: {prec:.3} | '
                 'loss => '
                 'ce: {ce_loss:.4} / '
                 'reg: {reg_loss:.4} / '
                 'total: {total_loss:.4}').format(
                     epoch=epoch,
                     epochs=epochs,
                     iteration=iteration,
                     trained=(batch_index + 1) * batch_size,
                     total=dataset_size,
                     progress=(100. * (batch_index + 1) / dataset_batches),
                     prec=precision,
                     ce_loss=(cross_entropy_loss.data[0] / data_size),
                     reg_loss=(reg_loss.data[0] / data_size),
                     total_loss=(total_loss.data[0] / data_size),
                 ))

            # Send test precision to the visdom server.
            if iteration % eval_log_interval == 0:
                visual.visualize_scalar(utils.validate(model,
                                                       test_dataset,
                                                       test_size=test_size,
                                                       cuda=cuda,
                                                       verbose=False),
                                        'precision',
                                        iteration,
                                        env=model.name)

            # Send losses to the visdom server.
            if iteration % loss_log_interval == 0:
                reg_losses_and_names = ([
                    overlap_loss.data / data_size,
                    uniform_loss.data / data_size,
                    split_loss.data / data_size,
                    reg_loss.data / data_size,
                ], ['overlap', 'uniform', 'split', 'total'])

                visual.visualize_scalar(overlap_loss.data / data_size,
                                        'overlap loss',
                                        iteration,
                                        env=model.name)
                visual.visualize_scalar(uniform_loss.data / data_size,
                                        'uniform loss',
                                        iteration,
                                        env=model.name)
                visual.visualize_scalar(split_loss.data / data_size,
                                        'split loss',
                                        iteration,
                                        env=model.name)
                visual.visualize_scalars(*reg_losses_and_names,
                                         'regularization losses',
                                         iteration,
                                         env=model.name)

                model_losses_and_names = ([
                    cross_entropy_loss.data / data_size,
                    reg_loss.data / data_size,
                    total_loss.data / data_size,
                ], ['cross entropy', 'regularization', 'total'])

                visual.visualize_scalar(cross_entropy_loss.data / data_size,
                                        'cross entropy loss',
                                        iteration,
                                        env=model.name)

                visual.visualize_scalar(reg_loss.data / data_size,
                                        'regularization loss',
                                        iteration,
                                        env=model.name)

                visual.visualize_scalars(*model_losses_and_names,
                                         'model losses',
                                         iteration,
                                         env=model.name)

            if iteration % weight_log_interval == 0:
                # Send visualized weights to the visdom server.
                weights = [
                    (w.data, p, q)
                    for i, g in enumerate(model.residual_block_groups)
                    for b in g.residual_blocks for w, p, q in (
                        (b.w1, b.p(), b.r()),
                        (b.w2, b.r(), b.q()),
                        (b.w3, b.p(), b.q()),
                    )
                    if i + 1 > (len(model.residual_block_groups) -
                                (len(model.split_sizes) - 1)) and w is not None
                ] + [(model.fc.linear.weight.data, model.fc.p(), model.fc.q())]

                names = [
                    'g{i}-b{j}-w{k}'.format(i=i + 1, j=j + 1, k=k + 1)
                    for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for k, w in enumerate((b.w1, b.w2, b.w3))
                    if i + 1 > (len(model.residual_block_groups) -
                                (len(model.split_sizes) - 1)) and w is not None
                ] + ['fc-w']

                for (w, p, q), name in zip(weights, names):
                    visual.visualize_kernel(
                        splits.block_diagonalize_kernel(w, p, q),
                        name,
                        label='epoch{}-{}'.format(epoch, batch_index + 1),
                        update_window_without_label=True,
                        env=model.name,
                    )

                # Send visualized split indicators to the visdom server.
                indicators = [
                    q.data for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for k, q in enumerate((b.p(), b.r())) if q is not None
                ] + [model.fc.p().data, model.fc.q().data]

                names = [
                    'g{i}-b{j}-{indicator}'.format(
                        i=i + 1, j=j + 1, indicator=ind)
                    for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for ind, q in zip(('p', 'r'), (b.p(), b.r()))
                    if q is not None
                ] + ['fc-p', 'fc-q']

                for q, name in zip(indicators, names):
                    # Stretch the split indicators before visualization.
                    q_diagonalized = splits.block_diagonalize_indacator(q)
                    q_diagonalized_expanded = q_diagonalized\
                        .view(*q.size(), 1)\
                        .repeat(1, 20, 1)\
                        .view(-1, q.size()[1])

                    visual.visualize_kernel(q_diagonalized_expanded,
                                            name,
                                            label='epoch{}-{}'.format(
                                                epoch, batch_index + 1),
                                            update_window_without_label=True,
                                            env=model.name,
                                            w=100,
                                            h=100)

            if iteration % checkpoint_interval == 0:
                # notify that we've reached a new checkpoint.
                print()
                print()
                print('#############')
                print('# checkpoint!')
                print('#############')
                print()

                # test the model.
                model_precision = utils.validate(model,
                                                 test_dataset or train_dataset,
                                                 test_size=test_size,
                                                 cuda=cuda,
                                                 verbose=True)

                # update best precision if needed.
                is_best = model_precision > best_precision
                best_precision = max(model_precision, best_precision)

                # save the checkpoint.
                utils.save_checkpoint(model,
                                      model_dir,
                                      epoch,
                                      model_precision,
                                      best=is_best)
                print()
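A hypothetical call-site sketch for the loop above: SplitResNet is a placeholder assumed to provide reg_loss(), residual_block_groups, split_sizes, and the fc.p()/fc.q() indicators the logging code relies on, and the CIFAR-10 dataset names are assumptions. Note that lr_decay_epochs must be an actual list of milestone epochs, since it is passed straight to MultiStepLR.

import torch

model = SplitResNet()  # assumed model exposing reg_loss() and split indicators
train(model, cifar10_train, test_dataset=cifar10_test,
      lr=1e-4, lr_decay_epochs=[60, 120, 160],
      batch_size=32, epochs=5,
      cuda=torch.cuda.is_available())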
Example #5
from torch import nn, optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm

import utils   # project-local data/validation helpers referenced below
import visual  # project-local visdom wrappers referenced below


def train(model,
          train_dataset,
          test_dataset=None,
          collate_fn=None,
          model_dir='models',
          lr=1e-3,
          lr_decay=.1,
          lr_decay_epochs=None,
          weight_decay=1e-04,
          grad_clip_norm=10.,
          batch_size=32,
          test_size=256,
          epochs=5,
          eval_log_interval=30,
          gradient_log_interval=50,
          loss_log_interval=30,
          checkpoint_interval=500,
          resume_best=False,
          resume_latest=False,
          cuda=True):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )
    scheduler = MultiStepLR(optimizer, lr_decay_epochs, gamma=lr_decay)

    model.train()
    epoch_start = 1
    best_precision = 0

    if resume_best or resume_latest:
        epoch_start, best_precision = utils.load_checkpoint(model,
                                                            model_dir,
                                                            best=resume_best)

    for epoch in range(epoch_start, epochs + 1):
        scheduler.step(epoch - 1)
        data_loader = utils.get_data_loader(train_dataset,
                                            batch_size,
                                            cuda=cuda,
                                            collate_fn=collate_fn)
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (x, q, a) in data_stream:
            # where are we?
            data_size = len(x)
            dataset_size = len(data_loader.dataset)
            dataset_batches = len(data_loader)
            iteration = ((epoch - 1) * (dataset_size // batch_size) +
                         batch_index + 1)

            x = Variable(x).cuda() if cuda else Variable(x)
            q = Variable(q).cuda() if cuda else Variable(q)
            a = Variable(a).cuda() if cuda else Variable(a)

            optimizer.zero_grad()
            scores = model(x, q)
            loss = criterion(scores, a)
            loss.backward()

            _, predicted = scores.max(1)
            precision = (predicted == a).sum().data[0] / data_size

            if grad_clip_norm:
                nn.utils.clip_grad_norm(model.parameters(), grad_clip_norm)
            optimizer.step()

            # update & display statistics.
            data_stream.set_description(
                ('epoch: {epoch}/{epochs} | '
                 'total iteration: {iteration} | '
                 'progress: [{trained}/{total}] ({progress:.0f}%) | '
                 'prec: {prec:.4} | '
                 'loss: {loss:.4} ').format(
                     epoch=epoch,
                     epochs=epochs,
                     iteration=iteration,
                     trained=(batch_index + 1) * batch_size,
                     total=dataset_size,
                     progress=(100. * (batch_index + 1) / dataset_batches),
                     prec=precision,
                     loss=loss.data[0],
                 ))

            # Send gradient norms to the visdom server.
            if iteration % gradient_log_interval == 0:
                names, gradients = zip(*[(n, p.grad.norm().data)
                                         for n, p in model.named_parameters()])
                visual.visualize_scalars(gradients,
                                         names,
                                         'gradient l2 norms',
                                         iteration,
                                         env=model.name)

            # Send test precision to the visdom server.
            if iteration % eval_log_interval == 0:
                visual.visualize_scalar(utils.validate(model,
                                                       test_dataset,
                                                       test_size=test_size,
                                                       cuda=cuda,
                                                       collate_fn=collate_fn,
                                                       verbose=False),
                                        'precision',
                                        iteration,
                                        env=model.name)

            # Send losses to the visdom server.
            if iteration % loss_log_interval == 0:
                visual.visualize_scalar(loss.data / data_size,
                                        'loss',
                                        iteration,
                                        env=model.name)

            if iteration % checkpoint_interval == 0:
                # notify that we've reached a new checkpoint.
                print()
                print()
                print('#############')
                print('# checkpoint!')
                print('#############')
                print()

                # test the model.
                model_precision = utils.validate(model,
                                                 test_dataset or train_dataset,
                                                 test_size=test_size,
                                                 cuda=cuda,
                                                 collate_fn=collate_fn,
                                                 verbose=True)

                # update best precision if needed.
                is_best = model_precision > best_precision
                best_precision = max(model_precision, best_precision)

                # save the checkpoint.
                utils.save_checkpoint(model,
                                      model_dir,
                                      epoch,
                                      model_precision,
                                      best=is_best)
                print()
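A hypothetical call-site sketch for the loop above, which expects batches of (x, q, a) triples, e.g. image, question, and answer for a VQA-style task: VQAModel, the datasets, and vqa_collate are placeholders, not part of the example. As in the previous example, lr_decay_epochs must be a list of milestone epochs for MultiStepLR.

import torch

model = VQAModel()  # assumed model taking (x, q) and returning answer scores
train(model, vqa_train, test_dataset=vqa_val,
      collate_fn=vqa_collate,
      lr=1e-3, lr_decay_epochs=[10, 20],
      batch_size=32, epochs=5,
      cuda=torch.cuda.is_available())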