# NOTE: the training routines in this section appear to come from separate
# training scripts; each assumes the usual module-level imports
# (torch.nn as nn, torch.optim as optim, torch.autograd.Variable, tqdm, and
# MultiStepLR where used) plus the project-local utils, visual and splits
# helper modules.


def cb(solver, progress, batch_index, result):
    # NOTE: current_task, total_tasks, total_iterations, batch_size and the
    # logging / evaluation settings used below are free variables -- this
    # callback is meant to be defined inside, and closed over by, the solver
    # training routine.
    iteration = (current_task - 1) * total_iterations + batch_index
    progress.set_description((
        '<Training Solver> '
        'task: {task}/{tasks} | '
        'progress: [{trained}/{total}] ({percentage:.0f}%) | '
        'loss: {loss:.4} | '
        'prec: {prec:.4}'
    ).format(
        task=current_task,
        tasks=total_tasks,
        trained=batch_size * batch_index,
        total=batch_size * total_iterations,
        percentage=(100. * batch_index / total_iterations),
        loss=result['loss'],
        prec=result['precision'],
    ))

    # log the loss of the solver.
    if iteration % loss_log_interval == 0:
        visual.visualize_scalar(
            result['loss'], 'solver loss', iteration, env=env,
        )

    # evaluate the solver on multiple tasks.
    if iteration % eval_log_interval == 0:
        names = ['task {}'.format(i + 1) for i in range(len(test_datasets))]
        precs = [
            utils.validate(
                solver, test_datasets[i], test_size=test_size,
                cuda=cuda, verbose=False, collate_fn=collate_fn,
                train=(valid_proportion <= 0),
                valid_proportion=valid_proportion,
            ) if i + 1 <= current_task else 0
            for i in range(len(test_datasets))
        ]
        title = 'precision ({replay_mode})'.format(replay_mode=replay_mode)
        visual.visualize_scalars(precs, names, title, iteration, env=env)
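# A minimal sketch (not part of the original code) of how a callback like
# `cb` is typically produced: as a closure built inside the training routine,
# so the task counters and logging settings referenced above are captured
# from the enclosing scope. `make_solver_callback` and its parameter list are
# hypothetical names used only for illustration.
def make_solver_callback(current_task, total_tasks, total_iterations,
                         batch_size, loss_log_interval, eval_log_interval,
                         test_datasets, test_size, cuda, collate_fn,
                         valid_proportion, replay_mode, env):
    def cb(solver, progress, batch_index, result):
        ...  # identical body to `cb` above, now closing over the arguments.
    return cb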
def train_model(model, dataset, epochs=10, batch_size=32, sample_size=32,
                lr=3e-04, weight_decay=1e-5,
                loss_log_interval=30,
                image_log_interval=300,
                checkpoint_dir='./checkpoints',
                resume=False,
                cuda=False):
    # prepare optimizer and model
    model.train()
    optimizer = optim.Adam(
        model.parameters(), lr=lr,
        weight_decay=weight_decay,
    )

    if resume:
        epoch_start = utils.load_checkpoint(model, checkpoint_dir)
    else:
        epoch_start = 1

    for epoch in range(epoch_start, epochs + 1):
        data_loader = utils.get_data_loader(dataset, batch_size, cuda=cuda)
        data_stream = tqdm(enumerate(data_loader, 1))
        for batch_index, (x, _) in data_stream:
            # where are we?
            iteration = (epoch - 1) * (len(dataset) // batch_size) + batch_index

            # prepare data on gpu if needed
            x = Variable(x).cuda() if cuda else Variable(x)

            # flush gradients and run the model forward
            optimizer.zero_grad()
            (mean, logvar), x_reconstructed = model(x)
            reconstruction_loss = model.reconstruction_loss(x_reconstructed, x)
            kl_divergence_loss = model.kl_divergence_loss(mean, logvar)
            total_loss = reconstruction_loss + kl_divergence_loss

            # backprop gradients from the loss
            total_loss.backward()
            optimizer.step()

            # update progress
            data_stream.set_description((
                'epoch: {epoch} | '
                'iteration: {iteration} | '
                'progress: [{trained}/{total}] ({progress:.0f}%) | '
                'loss => '
                'total: {total_loss:.4f} / '
                're: {reconstruction_loss:.3f} / '
                'kl: {kl_divergence_loss:.3f}'
            ).format(
                epoch=epoch,
                iteration=iteration,
                trained=batch_index * len(x),
                total=len(data_loader.dataset),
                progress=(100. * batch_index / len(data_loader)),
                total_loss=total_loss.data.item(),
                reconstruction_loss=reconstruction_loss.data.item(),
                kl_divergence_loss=kl_divergence_loss.data.item(),
            ))

            # send losses to the visdom server.
            if iteration % loss_log_interval == 0:
                losses = [
                    reconstruction_loss.data.item(),
                    kl_divergence_loss.data.item(),
                    total_loss.data.item(),
                ]
                names = ['reconstruction', 'kl divergence', 'total']
                visual.visualize_scalars(
                    losses, names, 'loss', iteration, env=model.name,
                )

            # send generated samples to the visdom server.
            if iteration % image_log_interval == 0:
                images = model.sample(sample_size)
                visual.visualize_images(
                    images, 'generated samples', env=model.name,
                )

        # notify that we've reached a new checkpoint.
        print()
        print()
        print('#############')
        print('# checkpoint!')
        print('#############')
        print()

        # save the checkpoint.
        utils.save_checkpoint(model, checkpoint_dir, epoch)
        print()
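# Usage sketch (illustrative, with assumed names): `VAE` stands in for any
# nn.Module whose forward() returns ((mean, logvar), x_reconstructed) and
# which exposes reconstruction_loss(), kl_divergence_loss(), sample() and a
# `.name` attribute, as the loop above requires. The dataset just needs to
# yield (image, label) pairs.
if __name__ == '__main__':
    from torchvision import datasets, transforms
    from model import VAE  # hypothetical module / class name

    mnist = datasets.MNIST(
        './datasets/mnist', train=True, download=True,
        transform=transforms.ToTensor(),
    )
    vae = VAE()  # constructor arguments omitted; they depend on the model.
    train_model(
        vae, mnist, epochs=10, batch_size=32,
        checkpoint_dir='./checkpoints', cuda=False,
    )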
def train(model, train_datasets, test_datasets, epochs_per_task=10,
          batch_size=64, test_size=1024, consolidate=True,
          fisher_estimation_sample_size=1024,
          lr=1e-3, weight_decay=1e-5, lamda=3,
          loss_log_interval=30,
          eval_log_interval=50,
          cuda=False):
    # prepare the loss criterion and the optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)

    # set the model's mode to training mode.
    model.train()

    for task, train_dataset in enumerate(train_datasets, 1):
        for epoch in range(1, epochs_per_task + 1):
            # prepare the data loaders.
            data_loader = utils.get_data_loader(
                train_dataset, batch_size=batch_size,
                cuda=cuda,
            )
            data_stream = tqdm(enumerate(data_loader, 1))

            for batch_index, (x, y) in data_stream:
                # where are we?
                data_size = len(x)
                dataset_size = len(data_loader.dataset)
                dataset_batches = len(data_loader)
                previous_task_iteration = sum([
                    epochs_per_task * len(d) // batch_size
                    for d in train_datasets[:task - 1]
                ])
                current_task_iteration = (
                    (epoch - 1) * dataset_batches + batch_index
                )
                iteration = (
                    previous_task_iteration +
                    current_task_iteration
                )

                # prepare the data.
                x = x.view(data_size, -1)
                x = Variable(x).cuda() if cuda else Variable(x)
                y = Variable(y).cuda() if cuda else Variable(y)

                # run the model and backpropagate the errors.
                optimizer.zero_grad()
                scores = model(x)
                ce_loss = criterion(scores, y)
                ewc_loss = model.ewc_loss(lamda, cuda=cuda)
                loss = ce_loss + ewc_loss
                loss.backward()
                optimizer.step()

                # calculate the training precision.
                _, predicted = scores.max(1)
                precision = (predicted == y).sum().item() / len(x)

                data_stream.set_description((
                    'task: {task}/{tasks} | '
                    'epoch: {epoch}/{epochs} | '
                    'progress: [{trained}/{total}] ({progress:.0f}%) | '
                    'prec: {prec:.4} | '
                    'loss => '
                    'ce: {ce_loss:.4} / '
                    'ewc: {ewc_loss:.4} / '
                    'total: {loss:.4}'
                ).format(
                    task=task,
                    tasks=len(train_datasets),
                    epoch=epoch,
                    epochs=epochs_per_task,
                    trained=batch_index * batch_size,
                    total=dataset_size,
                    progress=(100. * batch_index / dataset_batches),
                    prec=precision,
                    ce_loss=ce_loss.item(),
                    ewc_loss=ewc_loss.item(),
                    loss=loss.item(),
                ))

                # Send test precision to the visdom server.
                if iteration % eval_log_interval == 0:
                    names = [
                        'task {}'.format(i + 1)
                        for i in range(len(train_datasets))
                    ]
                    precs = [
                        utils.validate(
                            model, test_datasets[i], test_size=test_size,
                            cuda=cuda, verbose=False,
                        ) if i + 1 <= task else 0
                        for i in range(len(train_datasets))
                    ]
                    title = (
                        'precision (consolidated)' if consolidate else
                        'precision'
                    )
                    visual.visualize_scalars(
                        precs, names, title,
                        iteration, env=model.name,
                    )

                # Send losses to the visdom server.
                if iteration % loss_log_interval == 0:
                    title = 'loss (consolidated)' if consolidate else 'loss'
                    visual.visualize_scalars(
                        [loss.data, ce_loss.data, ewc_loss.data],
                        ['total', 'cross entropy', 'ewc'],
                        title, iteration, env=model.name,
                    )

        if consolidate:
            # estimate the fisher information of the parameters and
            # consolidate them in the network.
            model.consolidate(model.estimate_fisher(
                train_dataset, fisher_estimation_sample_size
            ))
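# Usage sketch (illustrative, with assumed names): `MLP` and
# `get_permuted_mnist` are hypothetical stand-ins for the project's model
# class and task-construction helper. The loop above only requires that the
# model implements ewc_loss(), estimate_fisher() and consolidate(), and that
# the train/test task lists are index-aligned.
if __name__ == '__main__':
    from model import MLP                 # hypothetical
    from data import get_permuted_mnist   # hypothetical

    train_datasets, test_datasets = get_permuted_mnist(n_tasks=3)
    mlp = MLP(input_size=28 * 28, output_size=10)
    train(
        mlp, train_datasets, test_datasets,
        epochs_per_task=10, batch_size=64,
        consolidate=True, lamda=3, cuda=False,
    )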
def train(model, train_dataset, test_dataset=None, model_dir='models',
          lr=1e-04, lr_decay=.1, lr_decay_epochs=None, weight_decay=1e-04,
          gamma1=1., gamma2=1., gamma3=10.,
          batch_size=32, test_size=256, epochs=5,
          eval_log_interval=30,
          loss_log_interval=30,
          weight_log_interval=500,
          checkpoint_interval=500,
          resume_best=False,
          resume_latest=False,
          cuda=False):
    # prepare the loss criterion, the optimizer and the lr scheduler.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr,
                           weight_decay=weight_decay)
    scheduler = MultiStepLR(optimizer, lr_decay_epochs, gamma=lr_decay)

    # prepare the model and statistics.
    model.train()
    epoch_start = 1
    best_precision = 0

    # load checkpoint if needed.
    if resume_latest or resume_best:
        epoch_start, best_precision = utils.load_checkpoint(
            model, model_dir, best=resume_best,
        )

    for epoch in range(epoch_start, epochs + 1):
        # adjust learning rate if needed.
        scheduler.step(epoch - 1)

        # prepare a data stream for the epoch.
        data_loader = utils.get_data_loader(
            train_dataset, batch_size, cuda=cuda,
        )
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (data, labels) in data_stream:
            # where are we?
            data_size = len(data)
            dataset_size = len(data_loader.dataset)
            dataset_batches = len(data_loader)
            iteration = (
                (epoch - 1) * (dataset_size // batch_size) +
                batch_index + 1
            )

            # clear the gradients.
            optimizer.zero_grad()

            # run the network.
            x = Variable(data).cuda() if cuda else Variable(data)
            labels = Variable(labels).cuda() if cuda else Variable(labels)
            scores = model(x)
            _, predicted = scores.max(1)
            precision = (labels == predicted).sum().item() / data_size

            # update the network.
            cross_entropy_loss = criterion(scores, labels)
            overlap_loss, uniform_loss, split_loss = model.reg_loss()
            overlap_loss *= gamma1
            uniform_loss *= gamma3
            split_loss *= gamma2
            reg_loss = overlap_loss + uniform_loss + split_loss
            total_loss = cross_entropy_loss + reg_loss
            total_loss.backward(retain_graph=True)
            optimizer.step()

            # update & display statistics.
            data_stream.set_description((
                'epoch: {epoch}/{epochs} | '
                'it: {iteration} | '
                'progress: [{trained}/{total}] ({progress:.0f}%) | '
                'prec: {prec:.3} | '
                'loss => '
                'ce: {ce_loss:.4} / '
                'reg: {reg_loss:.4} / '
                'total: {total_loss:.4}'
            ).format(
                epoch=epoch,
                epochs=epochs,
                iteration=iteration,
                trained=(batch_index + 1) * batch_size,
                total=dataset_size,
                progress=(100. * (batch_index + 1) / dataset_batches),
                prec=precision,
                ce_loss=(cross_entropy_loss.item() / data_size),
                reg_loss=(reg_loss.item() / data_size),
                total_loss=(total_loss.item() / data_size),
            ))

            # Send test precision to the visdom server.
            if iteration % eval_log_interval == 0:
                visual.visualize_scalar(utils.validate(
                    model, test_dataset, test_size=test_size,
                    cuda=cuda, verbose=False,
                ), 'precision', iteration, env=model.name)

            # Send losses to the visdom server.
            if iteration % loss_log_interval == 0:
                reg_losses_and_names = ([
                    overlap_loss.data / data_size,
                    uniform_loss.data / data_size,
                    split_loss.data / data_size,
                    reg_loss.data / data_size,
                ], ['overlap', 'uniform', 'split', 'total'])
                visual.visualize_scalar(
                    overlap_loss.data / data_size,
                    'overlap loss', iteration, env=model.name,
                )
                visual.visualize_scalar(
                    uniform_loss.data / data_size,
                    'uniform loss', iteration, env=model.name,
                )
                visual.visualize_scalar(
                    split_loss.data / data_size,
                    'split loss', iteration, env=model.name,
                )
                visual.visualize_scalars(
                    *reg_losses_and_names,
                    'regularization losses', iteration, env=model.name,
                )

                model_losses_and_names = ([
                    cross_entropy_loss.data / data_size,
                    reg_loss.data / data_size,
                    total_loss.data / data_size,
                ], ['cross entropy', 'regularization', 'total'])
                visual.visualize_scalar(
                    cross_entropy_loss.data / data_size,
                    'cross entropy loss', iteration, env=model.name,
                )
                visual.visualize_scalar(
                    reg_loss.data / data_size,
                    'regularization loss', iteration, env=model.name,
                )
                visual.visualize_scalars(
                    *model_losses_and_names,
                    'model losses', iteration, env=model.name,
                )

            if iteration % weight_log_interval == 0:
                # Send visualized weights to the visdom server.
                weights = [
                    (w.data, p, q)
                    for i, g in enumerate(model.residual_block_groups)
                    for b in g.residual_blocks
                    for w, p, q in (
                        (b.w1, b.p(), b.r()),
                        (b.w2, b.r(), b.q()),
                        (b.w3, b.p(), b.q()),
                    )
                    if i + 1 > (
                        len(model.residual_block_groups) -
                        (len(model.split_sizes) - 1)
                    ) and w is not None
                ] + [(model.fc.linear.weight.data, model.fc.p(), model.fc.q())]
                names = [
                    'g{i}-b{j}-w{k}'.format(i=i + 1, j=j + 1, k=k + 1)
                    for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for k, w in enumerate((b.w1, b.w2, b.w3))
                    if i + 1 > (
                        len(model.residual_block_groups) -
                        (len(model.split_sizes) - 1)
                    ) and w is not None
                ] + ['fc-w']

                for (w, p, q), name in zip(weights, names):
                    visual.visualize_kernel(
                        splits.block_diagonalize_kernel(w, p, q), name,
                        label='epoch{}-{}'.format(epoch, batch_index + 1),
                        update_window_without_label=True,
                        env=model.name,
                    )

                # Send visualized split indicators to the visdom server.
                indicators = [
                    q.data
                    for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for k, q in enumerate((b.p(), b.r()))
                    if q is not None
                ] + [model.fc.p().data, model.fc.q().data]
                names = [
                    'g{i}-b{j}-{indicator}'.format(
                        i=i + 1, j=j + 1, indicator=ind)
                    for i, g in enumerate(model.residual_block_groups)
                    for j, b in enumerate(g.residual_blocks)
                    for ind, q in zip(('p', 'r'), (b.p(), b.r()))
                    if q is not None
                ] + ['fc-p', 'fc-q']

                for q, name in zip(indicators, names):
                    # Stretch the split indicators before visualization.
                    q_diagonalized = splits.block_diagonalize_indacator(q)
                    q_diagonalized_expanded = q_diagonalized\
                        .view(*q.size(), 1)\
                        .repeat(1, 20, 1)\
                        .view(-1, q.size()[1])

                    visual.visualize_kernel(
                        q_diagonalized_expanded, name,
                        label='epoch{}-{}'.format(epoch, batch_index + 1),
                        update_window_without_label=True,
                        env=model.name, w=100, h=100,
                    )

            if iteration % checkpoint_interval == 0:
                # notify that we've reached a new checkpoint.
                print()
                print()
                print('#############')
                print('# checkpoint!')
                print('#############')
                print()

                # test the model.
                model_precision = utils.validate(
                    model, test_dataset or train_dataset,
                    test_size=test_size, cuda=cuda, verbose=True,
                )

                # update best precision if needed.
                is_best = model_precision > best_precision
                best_precision = max(model_precision, best_precision)

                # save the checkpoint.
                utils.save_checkpoint(model, model_dir, epoch,
                                      model_precision, best=is_best)
                print()
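# Usage sketch (illustrative, with assumed names): `SplitResNet` is a
# hypothetical stand-in for the split-regularized residual network the loop
# above expects -- i.e. a model exposing reg_loss(), residual_block_groups,
# split_sizes, an `fc` head with p()/q() split indicators, and a `.name`
# attribute used as the visdom environment. The dataset choice is also an
# assumption for illustration.
if __name__ == '__main__':
    from torchvision import datasets, transforms
    from model import SplitResNet  # hypothetical

    cifar10 = datasets.CIFAR10(
        './datasets/cifar10', train=True, download=True,
        transform=transforms.ToTensor(),
    )
    net = SplitResNet()  # constructor arguments depend on the model.
    train(
        net, cifar10,
        lr=1e-04, lr_decay_epochs=[2, 4], lr_decay=.1,
        gamma1=1., gamma2=1., gamma3=10.,
        batch_size=32, epochs=5, cuda=False,
    )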
def train(model, train_dataset, test_dataset=None, collate_fn=None,
          model_dir='models', lr=1e-3, lr_decay=.1, lr_decay_epochs=None,
          weight_decay=1e-04, grad_clip_norm=10.,
          batch_size=32, test_size=256, epochs=5,
          eval_log_interval=30,
          gradient_log_interval=50,
          loss_log_interval=30,
          checkpoint_interval=500,
          resume_best=False,
          resume_latest=False,
          cuda=True):
    # prepare the loss criterion, the optimizer and the lr scheduler.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(), lr=lr,
        weight_decay=weight_decay,
    )
    scheduler = MultiStepLR(optimizer, lr_decay_epochs, gamma=lr_decay)

    # prepare the model and statistics.
    model.train()
    epoch_start = 1
    best_precision = 0

    # load checkpoint if needed.
    if resume_best or resume_latest:
        epoch_start, best_precision = utils.load_checkpoint(
            model, model_dir, best=resume_best,
        )

    for epoch in range(epoch_start, epochs + 1):
        # adjust learning rate if needed.
        scheduler.step(epoch - 1)

        # prepare a data stream for the epoch.
        data_loader = utils.get_data_loader(
            train_dataset, batch_size,
            cuda=cuda, collate_fn=collate_fn,
        )
        data_stream = tqdm(enumerate(data_loader, 1))

        for batch_index, (x, q, a) in data_stream:
            # where are we?
            data_size = len(x)
            dataset_size = len(data_loader.dataset)
            dataset_batches = len(data_loader)
            iteration = (
                (epoch - 1) * (dataset_size // batch_size) +
                batch_index + 1
            )

            # prepare the (image, question, answer) batch on the gpu if needed.
            x = Variable(x).cuda() if cuda else Variable(x)
            q = Variable(q).cuda() if cuda else Variable(q)
            a = Variable(a).cuda() if cuda else Variable(a)

            # run the model and backpropagate the errors.
            optimizer.zero_grad()
            scores = model(x, q)
            loss = criterion(scores, a)
            loss.backward()

            # calculate the training precision.
            _, predicted = scores.max(1)
            precision = (predicted == a).sum().item() / data_size

            # clip the gradients and update the parameters.
            if grad_clip_norm:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

            # update & display statistics.
            data_stream.set_description((
                'epoch: {epoch}/{epochs} | '
                'total iteration: {iteration} | '
                'progress: [{trained}/{total}] ({progress:.0f}%) | '
                'prec: {prec:.4} | '
                'loss: {loss:.4} '
            ).format(
                epoch=epoch,
                epochs=epochs,
                iteration=iteration,
                trained=(batch_index + 1) * batch_size,
                total=dataset_size,
                progress=(100. * (batch_index + 1) / dataset_batches),
                prec=precision,
                loss=loss.item(),
            ))

            # Send gradient norms to the visdom server.
            if iteration % gradient_log_interval == 0:
                names, gradients = zip(*[
                    (n, p.grad.norm().data)
                    for n, p in model.named_parameters()
                ])
                visual.visualize_scalars(
                    gradients, names, 'gradient l2 norms',
                    iteration, env=model.name,
                )

            # Send test precision to the visdom server.
            if iteration % eval_log_interval == 0:
                visual.visualize_scalar(utils.validate(
                    model, test_dataset, test_size=test_size,
                    cuda=cuda, collate_fn=collate_fn, verbose=False,
                ), 'precision', iteration, env=model.name)

            # Send losses to the visdom server.
            if iteration % loss_log_interval == 0:
                visual.visualize_scalar(
                    loss.data / data_size, 'loss',
                    iteration, env=model.name,
                )

            if iteration % checkpoint_interval == 0:
                # notify that we've reached a new checkpoint.
                print()
                print()
                print('#############')
                print('# checkpoint!')
                print('#############')
                print()

                # test the model.
                model_precision = utils.validate(
                    model, test_dataset or train_dataset,
                    test_size=test_size, cuda=cuda,
                    collate_fn=collate_fn, verbose=True,
                )

                # update best precision if needed.
                is_best = model_precision > best_precision
                best_precision = max(model_precision, best_precision)

                # save the checkpoint.
                utils.save_checkpoint(model, model_dir, epoch,
                                      model_precision, best=is_best)
                print()
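# Usage sketch (illustrative, with assumed names): `RelationalNetwork` and
# `SortOfCLEVR` are hypothetical stand-ins for the model and dataset the loop
# above expects -- batches of (image, question, answer) triplets, a model
# called as model(x, q), and an optional collate_fn for variable-length
# questions.
if __name__ == '__main__':
    from model import RelationalNetwork   # hypothetical
    from data import SortOfCLEVR          # hypothetical

    dataset = SortOfCLEVR('./datasets/sort-of-clevr', train=True)
    rn = RelationalNetwork()  # constructor arguments depend on the model.
    train(
        rn, dataset,
        lr=1e-3, lr_decay_epochs=[3, 4], grad_clip_norm=10.,
        batch_size=32, epochs=5, cuda=False,
    )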