def main(args):
    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Load configuration
    config = make_config(args)

    # Experiment path
    exp_dir = make_dir(config, args.directory)

    # Initialize logging
    if rank == 0:
        logging.init(exp_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.directory)
    else:
        summary = None

    body_config = config["body"]
    optimizer_config = config["optimizer"]

    # Load data
    train_dataloader, val_dataloader = make_dataloader(args, config, rank, world_size)

    # Initialize model
    if body_config.getboolean("pretrained"):
        log_debug("Use pre-trained model %s", body_config.get("arch"))
    else:
        log_debug("Initialize model to train from scratch %s", body_config.get("arch"))

    # Load model
    model, output_dim = make_model(args, config)
    print(model)

    # Resume / pre-train
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(
            model, args.resume, ["body", "local_head_coarse", "local_head_fine"])
    elif args.pre_train:
        assert not args.resume, "resume and pre_train are mutually exclusive"
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(
            model, args.pre_train, ["body", "local_head_coarse", "local_head_fine"])
    else:
        # assert not args.eval, "--resume is needed in eval mode"
        snapshot = None

    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device),
                                    device_ids=[device_id],
                                    output_device=device_id,
                                    find_unused_parameters=True)

    # Create optimizer & scheduler
    optimizer, scheduler, parameters, batch_update, total_epochs = make_optimizer(
        model, config, epoch_length=len(train_dataloader))
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])

    # Training loop
    momentum = 1. - 1. / len(train_dataloader)
    meters = {
        "loss": AverageMeter((), momentum),
        "epipolar_loss": AverageMeter((), momentum),
        "consistency_loss": AverageMeter((), momentum),
    }

    if args.resume:
        start_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]
        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        start_epoch = 0
        best_score = {
            "val": 1000.0,
            "test": 0.0,
        }
        global_step = 0

    # Optional: evaluation only
    if args.eval:
        log_info("Evaluation epoch %d", start_epoch - 1)
        test(args, config, model, rank=rank, world_size=world_size,
             output_dim=output_dim, device=device)
        log_info("Evaluation Done .....")
        exit(0)

    for epoch in range(start_epoch, total_epochs):
        log_info("Starting epoch %d", epoch + 1)
        if not batch_update:
            scheduler.step(epoch)
        score = {}

        # Run training
        global_step = train(
            model, config, train_dataloader, optimizer, scheduler, meters,
            summary=summary, batch_update=batch_update,
            log_interval=config["general"].getint("log_interval"),
            epoch=epoch, num_epochs=total_epochs, global_step=global_step,
            output_dim=output_dim, world_size=world_size, rank=rank, device=device,
            loss_weights=optimizer_config.getstruct("loss_weights"))

        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(exp_dir, "model_{}.pth.tar".format(epoch))
            log_debug("Saving snapshot to %s", snapshot_file)
            meters_out_dict = {
                k + "_meter": v.state_dict() for k, v in meters.items()
            }
            save_snapshot(
                snapshot_file, config, epoch, 0, best_score, global_step,
                body=model.module.body.state_dict(),
                local_head_coarse=model.module.local_head_coarse.state_dict(),
                local_head_fine=model.module.local_head_fine.state_dict(),
                optimizer=optimizer.state_dict(),
                **meters_out_dict)

        # Run validation
        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)
            score['val'] = validate(
                model, config, val_dataloader,
                summary=summary, batch_update=batch_update,
                log_interval=config["general"].getint("log_interval"),
                epoch=epoch, num_epochs=total_epochs, global_step=global_step,
                output_dim=output_dim, world_size=world_size, rank=rank, device=device,
                loss_weights=optimizer_config.getstruct("loss_weights"))

        # Run test
        if (epoch + 1) % config["general"].getint("test_interval") == 0:
            log_info("Testing epoch %d", epoch + 1)
            score['test'] = test(args, config, model, rank=rank, world_size=world_size,
                                 output_dim=output_dim, device=device)

            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot

            if score['test'] > best_score['test']:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file,
                                path.join(exp_dir, "test_model_best.pth.tar"))
def validate(model, config, dataloader, **varargs):
    # create tuples for validation
    data_config = config["dataloader"]

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()

    loss_weights = varargs["loss_weights"]
    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()
    for it, batch in enumerate(dataloader):
        with torch.no_grad():
            # Upload batch
            batch = {
                k: batch[k].cuda(device=varargs["device"], non_blocking=True)
                for k in NETWORK_TRAIN_INPUTS
            }

            data_time_meter.update(torch.tensor(time.time() - data_time))
            batch_time = time.time()

            # Run network
            losses, _ = model(**batch, do_loss=True, do_prediction=True,
                              do_augmentation=True)
            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"],
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
def validate(model, config, dataloader, **varargs):
    # create tuples for validation
    data_config = config["dataloader"]

    distributed.barrier()
    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model, log_info, log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # Switch to eval mode
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    loss_weights = varargs["loss_weights"]
    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    data_time = time.time()
    for it, batch in enumerate(dataloader):
        with torch.no_grad():
            # Upload batch
            for k in NETWORK_INPUTS:
                if isinstance(batch[k][0], PackedSequence):
                    batch[k] = [
                        item.cuda(device=varargs["device"], non_blocking=True)
                        for item in batch[k]
                    ]
                else:
                    batch[k] = batch[k].cuda(device=varargs["device"], non_blocking=True)

            data_time_meter.update(torch.tensor(time.time() - data_time))
            batch_time = time.time()

            # Run network
            losses, _ = model(**batch, do_loss=True, do_prediction=True)
            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, batch

        # Log batch
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                None, "val", varargs["global_step"],
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([("loss", loss_meter),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return loss_meter.mean
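

# NOTE: all_reduce_losses is a repo helper whose implementation is not shown in this
# file. Below is a minimal sketch of the usual technique, assuming it averages each
# loss tensor across all distributed workers; the repo's helper may differ in detail
# (e.g. summing instead of averaging, or skipping reduction when world_size == 1).
from collections import OrderedDict

import torch.distributed as distributed


def all_reduce_losses_sketch(losses):
    """Average each scalar loss over all workers in the process group."""
    reduced = OrderedDict()
    world_size = distributed.get_world_size()
    for name, value in losses.items():
        value = value.clone()
        # Sum across workers, then divide to obtain the mean
        distributed.all_reduce(value, op=distributed.ReduceOp.SUM)
        reduced[name] = value / world_size
    return reduced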
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):
    # Create tuples for training
    data_config = config["dataloader"]

    # Switch to train mode
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    # dataloader.dataset.update()
    optimizer.zero_grad()

    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]
    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()
    torch.autograd.set_detect_anomaly(True)
    for it, batch in enumerate(dataloader):
        # Upload batch
        batch = {
            k: batch[k].cuda(device=varargs["device"], non_blocking=True)
            for k in NETWORK_TRAIN_INPUTS
        }

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        optimizer.zero_grad()
        losses, _ = model(**batch, do_loss=True, do_augmentation=True)
        distributed.barrier()
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()
        optimizer.step()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())
                if torch.isnan(loss_value).any():
                    # Pause for inspection if a loss becomes NaN
                    input()
        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step,
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("loss", meters["loss"]),
                             ("epipolar_loss", meters["epipolar_loss"]),
                             ("consistency_loss", meters["consistency_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step
def train(model, config, dataloader, optimizer, scheduler, meters, **varargs):
    # Create tuples for training
    data_config = config["dataloader"]

    distributed.barrier()
    avg_neg_distance = dataloader.dataset.create_epoch_tuples(
        model, log_info, log_debug,
        output_dim=varargs["output_dim"],
        world_size=varargs["world_size"],
        rank=varargs["rank"],
        device=varargs["device"],
        data_config=data_config)
    distributed.barrier()

    # Switch to train mode
    model.train()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])
    optimizer.zero_grad()

    global_step = varargs["global_step"]
    loss_weights = varargs["loss_weights"]
    data_time_meter = AverageMeter((), meters["loss"].momentum)
    batch_time_meter = AverageMeter((), meters["loss"].momentum)

    data_time = time.time()
    for it, batch in enumerate(dataloader):
        # Upload batch
        for k in NETWORK_INPUTS:
            if isinstance(batch[k][0], PackedSequence):
                batch[k] = [
                    item.cuda(device=varargs["device"], non_blocking=True)
                    for item in batch[k]
                ]
            else:
                batch[k] = batch[k].cuda(device=varargs["device"], non_blocking=True)

        # Measure data loading time
        data_time_meter.update(torch.tensor(time.time() - data_time))

        # Update scheduler
        global_step += 1
        if varargs["batch_update"]:
            scheduler.step(global_step)

        batch_time = time.time()

        # Run network
        losses, _ = model(**batch, do_loss=True, do_augmentation=True)
        distributed.barrier()
        losses = OrderedDict((k, v.mean()) for k, v in losses.items())
        losses["loss"] = sum(w * l for w, l in zip(loss_weights, losses.values()))

        losses["loss"].backward()
        optimizer.step()
        optimizer.zero_grad()
        if (it + 1) % 5 == 0:
            # Extra step/zero_grad every 5 iterations (gradients were already
            # zeroed above, so this is a no-op for gradient-only optimizers)
            optimizer.step()
            optimizer.zero_grad()

        # Gather from all workers
        losses = all_reduce_losses(losses)

        # Update meters
        with torch.no_grad():
            for loss_name, loss_value in losses.items():
                meters[loss_name].update(loss_value.cpu())
        batch_time_meter.update(torch.tensor(time.time() - batch_time))

        # Clean-up
        del batch, losses

        # Log
        if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
            logging.iteration(
                varargs["summary"], "train", global_step,
                varargs["epoch"] + 1, varargs["num_epochs"],
                it + 1, len(dataloader),
                OrderedDict([("lr_body", scheduler.get_lr()[0] * 1e6),
                             ("lr_ret", scheduler.get_lr()[1] * 1e6),
                             ("loss", meters["loss"]),
                             ("ret_loss", meters["ret_loss"]),
                             ("data_time", data_time_meter),
                             ("batch_time", batch_time_meter)]))

        data_time = time.time()

    return global_step