def main():
    """Load pretrained UNet weights and run ``SpecChannelUnet`` with them.

    Steps:
      1. Restrict visible CUDA devices to ``0,1,2``.
      2. Verify that ``ROOT_DIR`` exists.
      3. Build a ``UNet`` and load weights from ``TEST_UNET_WEIGHTS_PATH``.
         If that path contains ``'checkpoint'``, the weights are read from
         the file's ``'state_dict'`` entry. A leading ``'model.'`` prefix is
         stripped from every parameter name before a strict load.
      4. Wrap the network, set ``model_version`` and call ``work.train()``.

    Reads module-level configuration defined elsewhere in the file
    (``K``, ``ROOT_DIR``, ``TEST_UNET_WEIGHTS_PATH``, ``MAIN_DEVICE``,
    ``PRETRAINED``, ``TRACKGRAD``).

    NOTE(review): a second ``main`` appears later in this source. If both
    are in the same module, the later one shadows this one — confirm the
    two were meant to be separate scripts.

    Raises:
        FileNotFoundError: if ``ROOT_DIR`` does not exist. This is a subclass
            of ``Exception``, so existing ``except Exception`` callers still
            catch it.
        KeyError: if the checkpoint file lacks a ``'state_dict'`` entry.
        RuntimeError: from ``load_state_dict`` when the keys do not match
            the UNet exactly (``strict=True``).
    """
    # Must be set before CUDA is initialised to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'

    # Fail fast, before spending time on model construction and weight loading.
    if not os.path.exists(ROOT_DIR):
        raise FileNotFoundError('Directory does not exist')

    # SET MODEL
    u_net = UNet([32, 64, 128, 256, 512, 1024, 2048], K, None, verbose=False, useBN=True)

    # The lambda keeps every tensor on the storage it was deserialised to (CPU).
    state_dict = torch.load(TEST_UNET_WEIGHTS_PATH, map_location=lambda storage, loc: storage)
    # Heuristic based on the file name: checkpoint files wrap the weights
    # under a 'state_dict' key.
    if 'checkpoint' in TEST_UNET_WEIGHTS_PATH:
        state_dict = state_dict['state_dict']

    # Strip only a *leading* 'model.' (added when the UNet was saved as an
    # attribute named `model`). The old str.replace also removed
    # 'model.' in the middle of a key (e.g. 'submodel.w' -> 'subw').
    prefix = 'model.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    u_net.load_state_dict(new_state_dict, strict=True)

    model = Wrapper(u_net, main_device=MAIN_DEVICE)
    work = SpecChannelUnet(model, ROOT_DIR, PRETRAINED, main_device=MAIN_DEVICE, trackgrad=TRACKGRAD)
    work.model_version = 'UNIT_WEIGHTED_TESTING'
    work.train()
def main():
    """Build a dropout-enabled UNet and run the ``DWA`` trainer on it.

    The process is restricted to CUDA devices ``0,1,2``. The network is
    wrapped and then checked against ``ROOT_DIR``. A ``DWA`` instance is
    created, its ``model_version`` is set to ``'DWA'``, and ``train()`` is
    called.

    Reads module-level configuration defined elsewhere in the file
    (``K``, ``DROPOUT``, ``MAIN_DEVICE``, ``ROOT_DIR``, ``PRETRAINED``,
    ``TRACKGRAD``).

    Raises:
        Exception: if ``ROOT_DIR`` does not exist on disk.
    """
    # Equivalent to item assignment; must happen before CUDA initialises.
    os.environ.update(CUDA_VISIBLE_DEVICES='0,1,2')

    # SET MODEL
    # presumably the per-level channel widths of the UNet — confirm against UNet's signature
    widths = [32, 64, 128, 256, 512, 1024, 2048]
    network = UNet(widths, K, None, verbose=False, useBN=True, dropout=DROPOUT)
    wrapped = Wrapper(network, main_device=MAIN_DEVICE)

    data_dir_present = os.path.exists(ROOT_DIR)
    if not data_dir_present:
        raise Exception('Directory does not exist')

    trainer = DWA(
        wrapped,
        ROOT_DIR,
        PRETRAINED,
        main_device=MAIN_DEVICE,
        trackgrad=TRACKGRAD,
    )
    trainer.model_version = 'DWA'
    trainer.train()