import logging
import math
import os
import sys
import time

import numpy as np
import torch
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter  # the project may use tensorboardX instead
from tqdm import tqdm

# Project-local helpers assumed to be importable from this repository:
# create_model, create_optimizer, create_tough_dataset, compute_loss,
# estimate_scaler, get_cli_args, set_seed, set_worker_seed,
# PointOfInterestVoxelizedDataset

logger = logging.getLogger(__name__)


def load_model(model_dir, device):
    """ Loads the model from file """
    if isinstance(device, str):
        device = torch.device(device)
    # `model_dir` may be a directory containing 'model.pth.tar' or a direct path to the checkpoint
    fname = os.path.join(model_dir, 'model.pth.tar') if 'pth.tar' not in model_dir else model_dir
    checkpoint = torch.load(fname, map_location=str(device))
    model = create_model(checkpoint['args'], PointOfInterestVoxelizedDataset, device)
    model.load_state_dict(checkpoint['state_dict'])
    return model, checkpoint['args']
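
# A minimal usage sketch for load_model (the path below is hypothetical; any
# directory containing 'model.pth.tar', or a direct checkpoint path, works):
#
#   model, train_args = load_model('runs/experiment1', 'cpu')
#   model.eval()  # switch to inference mode before running the model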

def resume(args, dataset, device):
    """ Loads model, optimizer and scheduler state from a previous checkpoint. """
    print("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume, map_location=str(device))
    model = create_model(checkpoint['args'], dataset, device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer = create_optimizer(args, model)
    optimizer.load_state_dict(checkpoint['optimizer'])
    args.start_epoch = checkpoint['epoch']
    scheduler = MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_decay)
    scheduler.load_state_dict(checkpoint['scheduler'])
    return model, optimizer, scheduler
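
# Hypothetical resume flow: `args.resume` points at a checkpoint written by the
# training loop below, and training then continues from the stored epoch:
#
#   args.resume = 'runs/experiment1/model.pth.tar'
#   model, optimizer, scheduler = resume(args, train_dataset, torch.device('cuda'))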

def main():
    args = get_cli_args()

    print('Will save to ' + args.output_dir)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Record the exact command line for reproducibility
    with open(os.path.join(args.output_dir, 'cmdline.txt'), 'w') as f:
        f.write(" ".join(["'" + a + "'" if (len(a) == 0 or a[0] != '-') else a for a in sys.argv]))

    set_seed(args.seed)
    device = torch.device(args.device)
    writer = SummaryWriter(args.output_dir)

    train_dataset, test_dataset = create_tough_dataset(
        args, fold_nr=args.cvfold, n_folds=args.num_folds, seed=args.seed,
        exclude_Vertex_from_train=args.db_exclude_vertex,
        exclude_Prospeccts_from_train=args.db_exclude_prospeccts)
    logger.info('Train set size: %d, test set size: %d', len(train_dataset), len(test_dataset))

    # Create model and optimizer (or resume pre-existing)
    if args.resume != '':
        if args.resume == 'RESUME':
            args.resume = args.output_dir + '/model.pth.tar'
        model, optimizer, scheduler = resume(args, train_dataset, device)
    else:
        model = create_model(args, train_dataset, device)
        if args.input_normalization:
            model.set_input_scaler(estimate_scaler(args, train_dataset, nsamples=200))
        optimizer = create_optimizer(args, model)
        scheduler = MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_decay)

    ############
    def train():
        model.train()
        loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=args.batch_size // args.batch_parts,
                                             num_workers=args.nworkers, shuffle=True,
                                             drop_last=True, worker_init_fn=set_worker_seed)
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        loss_buffer, loss_stabil_buffer, pos_dist_buffer, neg_dist_buffer = [], [], [], []
        t0 = time.time()

        for bidx, batch in enumerate(loader):
            # Optionally cap the number of training samples per epoch
            if 0 < args.max_train_samples < bidx * args.batch_size // args.batch_parts:
                break
            t_loader = 1000 * (time.time() - t0)

            inputs = batch['inputs'].to(device)  # dimensions: batch_size x (4 or 2) x 24 x 24 x 24
            targets = batch['targets'].to(device)

            # Gradient accumulation: zero gradients only at the start of each group of `batch_parts` sub-batches
            if bidx % args.batch_parts == 0:
                optimizer.zero_grad()

            t0 = time.time()
            # Fold the pair/tuple dimension into the batch dimension for the forward pass
            outputs = model(inputs.view(-1, *inputs.shape[2:]))
            outputs = outputs.view(*inputs.shape[:2], -1)

            loss_joint, loss_match, loss_stabil, pos_dist, neg_dist = compute_loss(
                args, outputs, targets, True)
            loss_joint.backward()

            if bidx % args.batch_parts == args.batch_parts - 1:
                if args.batch_parts > 1:
                    # Average the accumulated gradients over the sub-batches
                    for p in model.parameters():
                        if p.grad is not None:
                            p.grad.data.div_(args.batch_parts)
                optimizer.step()
            t_trainer = 1000 * (time.time() - t0)

            loss_buffer.append(loss_match.item())
            loss_stabil_buffer.append(
                loss_stabil.item() if isinstance(loss_stabil, torch.Tensor) else loss_stabil)
            pos_dist_buffer.extend(pos_dist.cpu().numpy().tolist())
            neg_dist_buffer.extend(neg_dist.cpu().numpy().tolist())
            logger.debug('Batch loss %f, Loader time %f ms, Trainer time %f ms.',
                         loss_buffer[-1], t_loader, t_trainer)
            t0 = time.time()

        return {'loss': np.mean(loss_buffer),
                'loss_stabil': np.mean(loss_stabil_buffer),
                'pos_dist': np.mean(pos_dist_buffer),
                'neg_dist': np.mean(neg_dist_buffer)}

    ############
    def test():
        model.eval()
        loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=args.batch_size // args.batch_parts,
                                             num_workers=args.nworkers,
                                             worker_init_fn=set_worker_seed)
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        loss_buffer, loss_stabil_buffer, pos_dist_buffer, neg_dist_buffer = [], [], [], []

        with torch.no_grad():
            for bidx, batch in enumerate(loader):
                if 0 < args.max_test_samples < bidx * args.batch_size // args.batch_parts:
                    break
                inputs = batch['inputs'].to(device)
                targets = batch['targets'].to(device)

                outputs = model(inputs.view(-1, *inputs.shape[2:]))
                outputs = outputs.view(*inputs.shape[:2], -1)

                loss_joint, loss_match, loss_stabil, pos_dist, neg_dist = compute_loss(
                    args, outputs, targets, False)

                loss_buffer.append(loss_match.item())
                loss_stabil_buffer.append(
                    loss_stabil.item() if isinstance(loss_stabil, torch.Tensor) else loss_stabil)
                pos_dist_buffer.extend(pos_dist.cpu().numpy().tolist())
                neg_dist_buffer.extend(neg_dist.cpu().numpy().tolist())

        return {'loss': np.mean(loss_buffer),
                'loss_stabil': np.mean(loss_stabil_buffer),
                'pos_dist': np.mean(pos_dist_buffer),
                'neg_dist': np.mean(neg_dist_buffer)}

    ############
    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        print(f'Epoch {epoch}/{args.epochs} ({args.output_dir}):')

        train_stats = train()
        # Step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1 convention)
        scheduler.step()

        for k, v in train_stats.items():
            writer.add_scalar('train/' + k, v, epoch)
        print(f"-> Train distances: p {train_stats['pos_dist']}, n {train_stats['neg_dist']}, "
              f"\tLoss: {train_stats['loss']}")

        if (epoch + 1) % args.test_nth_epoch == 0 or epoch + 1 == args.epochs:
            test_stats = test()
            for k, v in test_stats.items():
                writer.add_scalar('test/' + k, v, epoch)
            print(f"-> Test distances: p {test_stats['pos_dist']}, n {test_stats['neg_dist']}, "
                  f"\tLoss: {test_stats['loss']}")

        torch.save({'epoch': epoch + 1,
                    'args': args,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict()},
                   os.path.join(args.output_dir, 'model.pth.tar'))

        # Abort early if training has diverged
        if math.isnan(train_stats['loss']):
            break
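
if __name__ == '__main__':
    # Assumed entry point: the script presumably invokes main() when run
    # directly; basicConfig is a minimal logging setup so that logger.info
    # and logger.debug output above is actually emitted.
    logging.basicConfig(level=logging.INFO)
    main()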