# --- Test-time setup: persist run parameters, rebuild the training-time
# --- preprocessing, then restore the trained networks from `checkpoint`.

# Dump this run's arguments (plus the host name) to <base_outdir>/param.json.
# NOTE(review): check_if_done is a project helper; presumably it stops or
# skips the run when param.json already exists -- confirm against its source.
json_fn = os.path.join(base_outdir, "param.json")
check_if_done(json_fn)
args.machine = os.uname()[1]  # element 1 of uname() is the node (host) name
save_dic_to_json(args.__dict__, json_fn)

# Shapes arrive as sequences of int-convertible items; normalise to int tuples.
train_img_shape = tuple(map(int, train_args.train_img_shape))
test_img_shape = tuple(map(int, args.test_img_shape))  # TODO

# Optional settings are forwarded only when the saved training arguments
# actually contain them; otherwise the helper's own default applies.
img_transform_kwargs = {"img_shape": train_img_shape}
if "normalize_way" in vars(train_args):
    img_transform_kwargs["normalize_way"] = train_args.normalize_way
img_transform = get_img_transform(**img_transform_kwargs)

label_transform_kwargs = {
    "img_shape": train_img_shape,
    "n_class": train_args.n_class,
}
if "background_id" in vars(train_args):
    label_transform_kwargs["background_id"] = train_args.background_id
label_transform = get_lbl_transform(**label_transform_kwargs)

# Target-domain data, served one sample per batch.
tgt_dataset = get_dataset(
    dataset_name=args.tgt_dataset,
    split=args.split,
    img_transform=img_transform,
    label_transform=label_transform,
    test=True,
    input_ch=train_args.input_ch,
)
target_loader = data.DataLoader(tgt_dataset, batch_size=1, pin_memory=True)

# Build the networks with the configuration used at training time, then load
# the saved weights for each of them.
G_3ch, G_1ch, F1, F2 = get_models(
    net_name=train_args.net,
    res=train_args.res,
    input_ch=train_args.input_ch,
    n_class=train_args.n_class,
    method=detailed_method,
    is_data_parallel=train_args.is_data_parallel,
)
G_3ch.load_state_dict(checkpoint['g_3ch_state_dict'])
G_1ch.load_state_dict(checkpoint['g_1ch_state_dict'])
F1.load_state_dict(checkpoint['f1_state_dict'])
# --- Train-time setup: persist run parameters, build the preprocessing
# --- transforms, and create the source / target datasets.

# NOTE(review): check_if_done is a project helper; presumably it stops or
# skips the run when the parameter file already exists -- confirm.
check_if_done(json_fn)
save_dic_to_json(args.__dict__, json_fn)

# The shape arrives as a sequence of int-convertible items.
train_img_shape = tuple(map(int, args.train_img_shape))

# Cropping (and the joint transform that goes with it) is enabled only for a
# strictly positive crop size.
use_crop = bool(args.crop_size > 0)
if use_crop:
    joint_transform = get_joint_transform(
        crop_size=args.crop_size,
        rotate_angle=args.rotate_angle,
    )
else:
    joint_transform = None

img_transform = get_img_transform(
    img_shape=train_img_shape,
    normalize_way=args.normalize_way,
    use_crop=use_crop,
)
label_transform = get_lbl_transform(
    img_shape=train_img_shape,
    n_class=args.n_class,
    background_id=args.background_id,
    use_crop=use_crop,
)

# Source and target datasets share every setting except name and split.
shared_dataset_kwargs = {
    "img_transform": img_transform,
    "label_transform": label_transform,
    "test": False,
    "input_ch": args.input_ch,
}
src_dataset = get_dataset(
    dataset_name=args.src_dataset,
    split=args.src_split,
    **shared_dataset_kwargs
)
tgt_dataset = get_dataset(
    dataset_name=args.tgt_dataset,
    split=args.tgt_split,
    **shared_dataset_kwargs
)