import torch
import torch.nn as nn

# Project-local helpers (ValDataset, TrainDataset, ModelBuilder,
# SegmentationAttentionModule, user_scattered_collate,
# UserScatteredDataParallel, patch_replication_callback, evaluate, train,
# checkpoint, create_optimizers) are assumed to be imported elsewhere in
# this file.


def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
    """Per-GPU evaluation worker: builds the model on `gpu_id` and evaluates
    the [start_idx, end_idx) shard of the validation list."""
    torch.cuda.set_device(gpu_id)

    # Dataset and Loader
    dataset_val = ValDataset(
        cfg.DATASET.root_dataset, cfg.DATASET.list_val, cfg.DATASET,
        start_idx=start_idx, end_idx=end_idx)
    loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=cfg.VAL.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=2)

    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query)
    net_enc_memory = ModelBuilder.build_encoder_memory(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.DATASET.num_class)
    net_att_query = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationAttentionModule(
        net_enc_query, net_enc_memory, net_att_query, net_att_memory,
        net_decoder, crit,
        normalize_key=cfg.MODEL.normalize_key,
        p_scalar=cfg.MODEL.p_scalar)
    segmentation_module.cuda()

    # Main loop
    evaluate(segmentation_module, loader_val, cfg, gpu_id, result_queue)
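# --------------------------------------------------------------------------- #
# Hypothetical launcher sketch, not part of the original file: `worker` is
# designed to be spawned once per GPU, each process evaluating a disjoint
# [start_idx, end_idx) shard of the validation list and reporting results
# through a shared queue. `launch_eval`, `num_files`, and the
# one-result-per-image assumption below are illustrative, not the repo's API.
import torch.multiprocessing as mp


def launch_eval(cfg, gpus, num_files):
    # CUDA in child processes generally requires the 'spawn' start method.
    mp.set_start_method('spawn', force=True)
    result_queue = mp.Queue()
    procs = []
    shard = (num_files + len(gpus) - 1) // len(gpus)  # ceil-divide work over GPUs
    for rank, gpu_id in enumerate(gpus):
        start_idx = rank * shard
        end_idx = min(start_idx + shard, num_files)
        proc = mp.Process(target=worker,
                          args=(cfg, gpu_id, start_idx, end_idx, result_queue))
        proc.start()
        procs.append(proc)
    # Drain the queue before joining; a full queue can otherwise block children.
    results = [result_queue.get() for _ in range(num_files)]
    for proc in procs:
        proc.join()
    return results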
def main(cfg, gpus):
    """Build the memory-attention segmentation model and run the training
    loop across the given GPUs."""
    # Network Builders
    net_enc_query = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_query)
    net_enc_memory = ModelBuilder.build_encoder_memory(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_enc_memory,
        num_class=cfg.DATASET.num_class)
    net_att_query = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_query)
    net_att_memory = ModelBuilder.build_encoder(
        arch='attention',
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_att_memory)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=-1)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationAttentionModule(
            net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit, cfg.TRAIN.deep_sup_scale,
            normalize_key=cfg.MODEL.normalize_key,
            p_scalar=cfg.MODEL.p_scalar)
    else:
        segmentation_module = SegmentationAttentionModule(
            net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit,
            normalize_key=cfg.MODEL.normalize_key,
            p_scalar=cfg.MODEL.p_scalar)

    # Dataset and Loader
    dataset_train = TrainDataset(
        cfg.DATASET.root_dataset, cfg.DATASET.list_train, cfg.DATASET,
        cfg.DATASET.ref_path, cfg.DATASET.ref_start, cfg.DATASET.ref_end,
        batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=len(gpus),  # we have modified data_parallel
        shuffle=False,  # we do not use this param
        collate_fn=user_scattered_collate,
        num_workers=cfg.TRAIN.workers,
        drop_last=True,
        pin_memory=True)

    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    # create loader iterator
    iterator_train = iter(loader_train)

    # load nets into gpu
    segmentation_module = UserScatteredDataParallel(
        segmentation_module, device_ids=gpus)
    # For sync bn
    patch_replication_callback(segmentation_module)
    segmentation_module.cuda()

    # Set up optimizers
    nets = (net_enc_query, net_enc_memory, net_att_query, net_att_memory,
            net_decoder, crit)
    optimizers = create_optimizers(nets, cfg)

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        train(segmentation_module, iterator_train, optimizers, history,
              epoch + 1, cfg)

        # checkpointing
        if (epoch + 1) % cfg.TRAIN.save_freq == 0:
            checkpoint(nets, history, cfg, epoch + 1)

    print('Training Done!')
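# --------------------------------------------------------------------------- #
# Hypothetical entry point, not part of the original section: one way `main`
# could be driven from the command line. It assumes a yacs-style config object
# (`merge_from_file` is the yacs API) exposed by a project-local `config`
# module; both the module and the flag names are illustrative assumptions.
if __name__ == '__main__':
    import argparse

    from config import cfg  # assumed project-local yacs config object

    parser = argparse.ArgumentParser(
        description='Train the memory-attention segmentation model')
    parser.add_argument('--cfg', required=True,
                        help='path to a YAML experiment config')
    parser.add_argument('--gpus', default='0',
                        help='comma-separated GPU ids, e.g. "0,1"')
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    gpus = [int(g) for g in args.gpus.split(',')]
    main(cfg, gpus)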