import logging
import os
import sys

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

# project-local modules; helpers such as prepare_dirs, get_writer, load_optimizer,
# override_hyperparams, write_metadata, plot_heatmaps and init_filenames are
# defined elsewhere in the repository
import data_loader
import models
import scores
import snapshot
import utils


def reinit_weights(results, net, net_init, test_loader, criterion, epoch, device):
    """For each convolutional layer in net, reinitialize its weights with the
    values of the corresponding weights in net_init, and store the test
    accuracy of the resulting model in results:

        results = {
            "init_from_epoch": {
                "epoch_id": [acc_1, ..., acc_l, ..., acc_L],
            }
        }

    where acc_l is the test accuracy of net with layer l replaced by the
    corresponding layer of net_init.
    """
    logger = logging.getLogger('reinit_weights')
    if results is None:
        results = {"init_from_epoch": {}}
    test_accs = []
    net.eval()
    net_init.eval()
    block_id, stage_id = 0, 0  # position within the network (bookkeeping only)
    conv_counter = 0
    init_type = ("random initialization." if str(epoch) == "rand"
                 else "weights of epoch {}.".format(epoch))
    for layer, layer_init in zip(net.features, net_init.features):
        if isinstance(layer, nn.Conv2d):
            # skip pointwise (1x1) convolutions
            if layer.kernel_size == (1, 1):
                continue
            conv_counter += 1
            logger.info("Reinitializing parameters of layer {} from {}".format(conv_counter, init_type))
            # swap in the initialization weights, keeping a copy of the trained ones
            weight_copy = layer.weight.clone()
            if layer.bias is not None:
                bias_copy = layer.bias.clone()
                layer.bias.data.copy_(layer_init.bias.data)
            layer.weight.data.copy_(layer_init.weight.data)
            _, top1_acc, _ = scores.evaluate(net, test_loader, criterion, device, topk=(1, 5))
            test_accs.append(top1_acc)
            # restore the original parameters
            layer.weight.data.copy_(weight_copy.data)
            if layer.bias is not None:
                layer.bias.data.copy_(bias_copy.data)
            block_id += 1
        elif isinstance(layer, nn.MaxPool2d):
            stage_id += 1
            block_id = 0
    results.setdefault("init_from_epoch", {})
    results["init_from_epoch"][str(epoch)] = test_accs
    return results
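# A minimal usage sketch for reinit_weights, illustrative rather than from the
# original script: the architecture name 'vgg11' and both checkpoint paths are
# assumptions, and snapshot.load_model is called with the same signature used
# in main() below. Each call appends one list of per-layer top-1 accuracies
# under results["init_from_epoch"].
def _reinit_example(classes, in_channels, test_loader, criterion, device):
    net = snapshot.load_model('vgg11', classes, 'vgg11_trained.pt', device, in_channels)
    net_init = snapshot.load_model('vgg11', classes, 'vgg11_init_0.pt', device, in_channels)
    results = reinit_weights(None, net, net_init, test_loader, criterion, 0, device)
    print(results["init_from_epoch"]["0"])  # one top-1 accuracy per (non-1x1) conv layer
    return results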
def train(model, end_epoch, train_loader, optimizer, criterion, scheduler, device,
          snapshot_dirname, start_epoch=0, snapshot_every=0, val_loader=None,
          kill_plateaus=False, best_acc1=0, writer=None, snapshot_all_until=0,
          filename='net', train_acc=False):
    """Train the specified model according to user options.

    Args:
        model (nn.Module) -- the model to be trained
        end_epoch (int) -- maximum number of epochs
        train_loader (DataLoader) -- train set loader
        optimizer (torch.optim optimizer) -- the optimizer to use
        criterion -- loss function to use
        scheduler -- learning rate scheduler
        device (torch.device) -- device to use
        snapshot_dirname (str) -- directory where snapshots are saved
        start_epoch (int) -- starting epoch (useful for resuming training)
        snapshot_every (int) -- frequency of snapshots (in epochs)
        val_loader (optional, DataLoader) -- validation set loader
        kill_plateaus (bool) -- stop early if validation accuracy plateaus
        train_acc (bool) -- whether to report accuracy on the train set
    """
    logger = logging.getLogger('train')
    converged = True  # used to kill models that plateau
    top1_prev = 0.
    if snapshot_every < 1:
        snapshot_every = end_epoch
    for epoch in range(start_epoch, end_epoch):
        avg_loss = 0.  # exponential moving average of the batch loss
        epoch_loss = 0.
        for batch_idx, (x, target) in enumerate(train_loader):
            optimizer.zero_grad()
            x = x.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            out = model(x)
            loss = criterion(out, target)
            avg_loss = avg_loss * 0.99 + loss.item() * 0.01
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            # report training loss
            if ((batch_idx + 1) % 100 == 0) or ((batch_idx + 1) == len(train_loader)):
                utils.print_train_loss(epoch, avg_loss, batch_idx, len(train_loader))
        # report training loss over the epoch
        epoch_loss /= len(train_loader)
        utils.print_train_loss_epoch(epoch, epoch_loss)
        if writer is not None:
            writer.add_scalar('Loss/train', avg_loss, epoch)
        if scheduler is not None:
            scheduler.step()
            if writer is not None:
                # get_last_lr() is the non-deprecated accessor after step()
                writer.add_scalar('Lr', scheduler.get_last_lr()[0], epoch)
        if (epoch < snapshot_all_until) or ((epoch + 1) % snapshot_every == 0) \
                or ((epoch + 1) == end_epoch):
            top1_acc, top5_acc = 0, 0
            if val_loader is not None:
                val_loss, top1_acc, top5_acc = scores.evaluate(model, val_loader, criterion,
                                                               device, topk=(1, 5))
                utils.print_val_loss(epoch, val_loss, top1_acc, top5_acc)
                model = model.train()
                if writer is not None:
                    writer.add_scalar('Loss/val', val_loss, epoch)
                    writer.add_scalar('Accuracy/val/top1', top1_acc, epoch)
                    writer.add_scalar('Accuracy/val/top5', top5_acc, epoch)
                # check whether training is stalling
                if kill_plateaus:
                    if top1_prev == top1_acc:
                        logger.debug("Previous val accuracy: {}, current val accuracy: {}. "
                                     "Model unlikely to converge. Quitting."
                                     .format(top1_prev, top1_acc))
                        converged = False
                        return model, converged
                    top1_prev = top1_acc
            if train_acc:
                train_loss, top1_train, top5_train = scores.evaluate(model, train_loader,
                                                                     criterion, device,
                                                                     topk=(1, 5))
                utils.print_train_loss_epoch(epoch, train_loss, top1_train, top5_train)
                model = model.train()
                if writer is not None:
                    writer.add_scalar('Accuracy/train/top1', top1_train, epoch)
                    writer.add_scalar('Accuracy/train/top5', top5_train, epoch)
            # save snapshot
            snapshot.save_snapshot(model, optimizer, scheduler, epoch, top1_acc, top5_acc,
                                   filename, snapshot_dirname)
    return model, converged
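# A minimal sketch of driving train() directly, assuming the loaders and the
# tensorboard writer are already built; the SGD/StepLR hyperparameters below
# are illustrative defaults, not the repository's.
def _train_example(net, train_loader, val_loader, writer, device):
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    net, converged = train(net, 90, train_loader, optimizer, criterion, scheduler, device,
                           snapshot_dirname='snapshots', snapshot_every=10,
                           val_loader=val_loader, writer=writer, filename='vgg11')
    return net, converged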
def main(args):
    """Entry point of the layer-reinitialization experiment."""
    if args.arch is None:
        print("Available architectures:")
        print(models.__all__)
        sys.exit(0)
    if args.dataset is None:
        print("Available datasets:")
        print(data_loader.__all__)
        sys.exit(0)
    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)
    results_path, plots_path = prepare_dirs(args)
    logger = logging.getLogger('reinit_weights')
    if torch.cuda.is_available() and args.cuda:
        device = torch.device("cuda:0")
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes
    if os.path.exists(args.load_from):
        logger.info("Loading {} from {}.".format(args.arch, args.load_from))
        net = snapshot.load_model(args.arch, classes, args.load_from, device, in_channels)
    else:
        logger.info("Cannot load trained model from {}: no such file.".format(args.load_from))
        sys.exit(1)
    net = net.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    # load test set
    _, test_loader, _ = data_loader.load_dataset(args.dataset, args.data_path, args.batch_size,
                                                 shuffle=True, augmentation=False,
                                                 num_workers=args.workers,
                                                 nclasses=args.subsample_classes,
                                                 class_sample_seed=args.class_sample_seed,
                                                 upscale=args.upscale,
                                                 upscale_padding=args.upscale_padding)
    # evaluate the trained model
    logger.info("Evaluating trained model on test set.")
    test_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device,
                                                    topk=(1, 5))
    utils.print_val_loss(0, test_loss, top1_acc, top5_acc, 'reinit_weights')
    results = {}
    if os.path.exists(args.inits_from):
        logger.info("Loading network weights initializations from {}.".format(args.inits_from))
        # iterate over the checkpoint filenames listed in the file
        with open(args.inits_from, 'r') as fp:
            for init_file in init_filenames(fp):
                if os.path.exists(init_file):
                    logger.info("Loading network weights from {}".format(init_file))
                    net_init = snapshot.load_model(args.arch, classes, init_file, device,
                                                   in_channels)
                else:
                    logger.warning("Warning. File not found: {}. Skipping.".format(init_file))
                    continue
                # recover the epoch id from the checkpoint filename
                splits = os.path.splitext(init_file)[0].split('_')
                if 'init' in splits:
                    epoch = 0
                else:
                    epoch = int(splits[-1]) + 1
                results = reinit_weights(results, net, net_init, test_loader, criterion,
                                         epoch, device)
    if args.rand:
        # load a random initialization, swap it in layer by layer and record accuracy
        logger.info("Loading random initialization.")
        random_init = models.load_model(args.arch, classes, pretrained=False,
                                        in_channels=in_channels)
        results = reinit_weights(results, net, random_init, test_loader, criterion,
                                 "rand", device)
    # args.load_from was verified above, so the trained checkpoint name is available
    filename = os.path.splitext(os.path.basename(args.load_from))[0]
    filename = 'reinit_' + filename
    results = write_metadata(results, args.load_from, args, top1_acc, top5_acc)
    # save results to json
    logger.info("Saving results to file.")
    snapshot.save_results(results, filename, results_path)
    # plot results
    logger.info("Plotting results.")
    plot_heatmaps(results, filename, plots_path)
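# init_filenames is used above but not shown in this excerpt; a plausible
# minimal implementation (an assumption, not the original) yields one
# checkpoint path per non-empty line of the open file:
def init_filenames(fp):
    """Yield checkpoint paths listed one per line in an open file object."""
    for line in fp:
        line = line.strip()
        if line:
            yield line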
def main(args):
    # set up project directories
    tb_logdir, snapshot_dir = prepare_dirs(args)
    # get logger
    logger = logging.getLogger('train')
    # tensorboard writer
    writer = get_writer(args, tb_logdir)
    use_cuda = torch.cuda.is_available() and args.cuda
    # set manual seed if required
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if use_cuda:
            torch.cuda.manual_seed_all(args.seed)
    # check for cuda support
    if use_cuda:
        device = torch.device("cuda:0")
        cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    # snapshot frequency
    if args.snapshot_every > 0 and not args.evaluate:
        logger.info('Saving snapshots to {}'.format(snapshot_dir))
    # load model
    classes, in_channels = data_loader.num_classes(args.dataset)
    if args.subsample_classes > 0:
        classes = args.subsample_classes
    net = models.load_model(args.arch, classes=classes, pretrained=args.pretrained,
                            in_channels=in_channels)
    if args.pretrained and args.resume_from == '':
        logger.info('Loading pretrained {} on ImageNet.'.format(args.arch))
    else:
        logger.info('Creating model {}.'.format(args.arch))
    if torch.cuda.device_count() > 1:
        logger.info("Running on {} GPUs".format(torch.cuda.device_count()))
        net.features = torch.nn.DataParallel(net.features)
    # move net to device
    net = net.to(device=device)
    # get data loaders for the specified dataset
    train_loader, test_loader, val_loader = data_loader.load_dataset(
        args.dataset, args.data_path, args.batch_size, shuffle=args.shuffle,
        augmentation=args.augmentation, noise=args.noise, split=args.split,
        num_workers=args.workers, split_seed=args.split_seed, noise_seed=args.noise_seed,
        stratified=args.stratified, nclasses=args.subsample_classes,
        class_sample_seed=args.class_sample_seed, no_normalization=args.unnormalize,
        upscale=args.upscale, upscale_padding=args.upscale_padding)
    # define loss
    criterion = nn.CrossEntropyLoss().to(device)
    start_epoch = args.start_epoch
    best_acc1, best_acc5 = 0, 0
    # load model from file
    if os.path.isfile(args.resume_from):
        # resume training given the state dictionary
        optimizer, scheduler = load_optimizer(args, net)
        try:
            net, optimizer, scheduler, start_epoch, best_acc1, best_acc5 = snapshot.load_snapshot(
                net, optimizer, scheduler, args.resume_from, device)
            if args.override:
                override_hyperparams(args, optimizer, scheduler)
        except KeyError:
            # the snapshot holds model weights only, with no optimizer state
            classes, in_channels = data_loader.num_classes(args.dataset)
            if args.subsample_classes > 0:
                classes = args.subsample_classes
            net = snapshot.load_model(args.arch, classes, args.resume_from, device, in_channels)
    else:
        # define optimizer
        optimizer, scheduler = load_optimizer(args, net)
    # evaluate model
    if args.evaluate:
        val_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)
        writer.close()
        return
    if args.evaluate_train:
        train_loss, top1_acc, top5_acc = scores.evaluate(net, train_loader, criterion, device,
                                                         topk=(1, 5))
        utils.print_train_loss_epoch(args.epochs, train_loss, top1_acc, top5_acc)
        if best_acc1 * best_acc5 > 0:
            # if nonzero, print best val accuracy
            utils.print_val_loss(args.epochs, -1., best_acc1, best_acc5)
        writer.add_scalar('Loss/train', train_loss, args.epochs)
        writer.add_scalar('Accuracy/train/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/train/top5', top5_acc, args.epochs)
        writer.close()
        return
    if args.eval_regularization_loss:
        regularization_loss = scores.compute_regularization_loss(net, args.weight_decay)
        utils.print_regularization_loss_epoch(args.epochs, regularization_loss)
        writer.add_scalar('Regularization loss', regularization_loss, args.epochs)
        writer.close()
        return
    utils.print_model_config(args)
    if start_epoch == 0:
        pretrained = 'pretrained_' if args.pretrained else 'init_'
        filename = args.arch + '_' + pretrained + str(start_epoch) + '.pt'
        logger.info("Saving model initialization to {}".format(filename))
        snapshot.save_model(net, filename, snapshot_dir)
    # train the model
    net.train()
    if val_loader is None and test_loader is not None:
        val_loader = test_loader
        logger.warning("Using TEST set to validate model during training!")
    net, converged = train(net, args.epochs, train_loader, optimizer, criterion, scheduler,
                           device, snapshot_dirname=snapshot_dir, start_epoch=start_epoch,
                           snapshot_every=args.snapshot_every, val_loader=val_loader,
                           kill_plateaus=args.kill_plateaus, best_acc1=best_acc1, writer=writer,
                           snapshot_all_until=args.snapshot_all_until, filename=args.arch,
                           train_acc=args.train_acc)
    if test_loader is not None:
        val_loss, top1_acc, top5_acc = scores.evaluate(net, test_loader, criterion, device,
                                                       topk=(1, 5))
        utils.print_val_loss(args.epochs, val_loss, top1_acc, top5_acc)
        net = net.train()
        writer.add_scalar('Loss/test', val_loss, args.epochs)
        writer.add_scalar('Accuracy/test/top1', top1_acc, args.epochs)
        writer.add_scalar('Accuracy/test/top5', top5_acc, args.epochs)
    # save final model
    if converged:
        pretrained = 'pretrained_' if args.pretrained else ''
        filename = args.arch + '_' + pretrained + str(args.epochs) + '.pt'
        snapshot.save_model(net, filename, snapshot_dir)
    writer.close()
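# A short sketch of reloading the final snapshot saved above and re-scoring it
# on the test set. The architecture name 'vgg11' and the path
# 'snapshots/vgg11_90.pt' are illustrative, following the
# arch + '_' + epochs + '.pt' naming used in main().
def _evaluate_final(classes, in_channels, test_loader, device):
    criterion = nn.CrossEntropyLoss().to(device)
    net = snapshot.load_model('vgg11', classes, 'snapshots/vgg11_90.pt', device, in_channels)
    _, top1, top5 = scores.evaluate(net, test_loader, criterion, device, topk=(1, 5))
    print('top-1: {:.2f}, top-5: {:.2f}'.format(top1, top5))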