Example #1
0
    # NOTE(review): fragment — the `if` matching the `else` below lies before this view;
    # the true-branch presumably applies the checkpoint's normalization scheme. TODO confirm.
    img_transform = get_img_transform(img_shape=train_img_shape, normalize_way=train_args.normalize_way)
else:
    img_transform = get_img_transform(img_shape=train_img_shape)

# Older checkpoints may lack `background_id` in their stored args; fall back to
# the helper's default when the attribute is absent.
if "background_id" in train_args.__dict__.keys():
    label_transform = get_lbl_transform(img_shape=train_img_shape, n_class=train_args.n_class,
                                        background_id=train_args.background_id)
else:
    label_transform = get_lbl_transform(img_shape=train_img_shape, n_class=train_args.n_class)

# Target-domain split, evaluated one sample at a time (batch_size=1).
tgt_dataset = get_dataset(dataset_name=args.tgt_dataset, split=args.split, img_transform=img_transform,
                          label_transform=label_transform, test=True, input_ch=train_args.input_ch)
target_loader = data.DataLoader(tgt_dataset, batch_size=1, pin_memory=True)

# Two generators (names suggest 3-channel and 1-channel input variants) plus two
# classifier heads, built with the hyper-parameters stored at training time.
G_3ch, G_1ch, F1, F2 = get_models(net_name=train_args.net, res=train_args.res, input_ch=train_args.input_ch,
                                  n_class=train_args.n_class,
                                  method=detailed_method, is_data_parallel=train_args.is_data_parallel)

G_3ch.load_state_dict(checkpoint['g_3ch_state_dict'])
G_1ch.load_state_dict(checkpoint['g_1ch_state_dict'])

F1.load_state_dict(checkpoint['f1_state_dict'])

# The second classifier's weights are restored only on request.
if args.use_f2:
    F2.load_state_dict(checkpoint['f2_state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
      .format(args.trained_checkpoint, checkpoint['epoch']))

# Inference mode (disables dropout, freezes batch-norm statistics).
G_3ch.eval()
G_1ch.eval()
F1.eval()
Example #2
0
    def __init__(self, args):
        """Restore a trained (G, F1, F2) network triple for inference.

        Network hyper-parameters come from the *training-time* arguments that
        were serialized inside ``args.trained_checkpoint``, not from *args*
        itself; *args* only supplies the checkpoint path and runtime flags
        (``use_ae``, ``use_f2``).
        """
        print("=> loading checkpoint '{}'".format(args.trained_checkpoint))
        assert os.path.exists(args.trained_checkpoint), args.trained_checkpoint
        checkpoint = torch.load(args.trained_checkpoint)
        train_args = checkpoint["args"]
        print("----- train args ------")
        pprint(train_args.__dict__, indent=4)
        print("-" * 50)
        self.input_ch = train_args.input_ch
        self.image_shape = tuple(int(dim) for dim in train_args.train_img_shape)
        print("=> loaded checkpoint '{}'".format(args.trained_checkpoint))

        # Resize, tensorize, then normalize (ImageNet-style mean/std constants).
        self.img_transform = Compose([
            Image.fromarray,
            Scale(self.image_shape, Image.BILINEAR),
            ToTensor(),
            Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        common = dict(net_name=train_args.net,
                      res=train_args.res,
                      input_ch=train_args.input_ch,
                      n_class=train_args.n_class)
        try:
            self.G, self.F1, self.F2 = get_models(
                method=train_args.method,
                is_data_parallel=train_args.is_data_parallel,
                use_ae=args.use_ae,
                **common)
        except AttributeError:
            # Older checkpoints predate `method`/`is_data_parallel`; rebuild
            # with the historical defaults instead.
            self.G, self.F1, self.F2 = get_models(
                method="MCD",
                is_data_parallel=False,
                **common)

        self.G.load_state_dict(checkpoint['g_state_dict'])
        self.F1.load_state_dict(checkpoint['f1_state_dict'])

        # F2 weights are only needed when the second classifier is used.
        if args.use_f2:
            self.F2.load_state_dict(checkpoint['f2_state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.trained_checkpoint, checkpoint['epoch']))

        # Inference mode for every sub-network, then move to GPU if present.
        for net in (self.G, self.F1, self.F2):
            net.eval()

        if torch.cuda.is_available():
            for net in (self.G, self.F1, self.F2):
                net.cuda()

        self.use_f2 = args.use_f2

        self.add_bg_loss = train_args.add_bg_loss
        self.n_class = train_args.n_class
        print('=> n_class = %d, add_bg_loss = %s' %
              (self.n_class, self.add_bg_loss))

def get_model_name_from_path(path):
    """Return the checkpoint file's base name with a trailing ".tar" removed.

    Fixes two defects of the original implementation:
    - ``path.split(os.path.sep)[-1]`` is replaced by ``os.path.basename`` for
      portability across path conventions.
    - ``str.replace(".tar", "")`` removed *every* occurrence of ".tar", so a
      name like "model.tar.tar" became "model"; now only a trailing ".tar"
      suffix is stripped.

    :param path: filesystem path to a checkpoint file.
    :return: base file name without its ".tar" extension.
    """
    name = os.path.basename(path)
    if name.endswith(".tar"):
        name = name[:-len(".tar")]
    return name


# Derive a combined save name from the two checkpoint file names.
args.savename = "AND".join([
    get_model_name_from_path(args.checkpoint),
    get_model_name_from_path(args.extra_checkpoint),
])

print("savename is %s " % (args.savename))

checkpoint = torch.load(args.checkpoint)

# Dual-generator setup: one generator per input modality, two classifier heads.
model_g_3ch, model_g_1ch, model_f1, model_f2 = get_models(
    net_name=args.net,
    res=args.res,
    input_ch=args.input_ch,
    n_class=args.n_class,
    method=detailed_method,
    is_data_parallel=args.is_data_parallel)

# Generators and classifiers get separate optimizers with identical settings.
optimizer_settings = dict(lr=args.lr,
                          opt=args.opt,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
generator_params = list(model_g_3ch.parameters()) + list(model_g_1ch.parameters())
classifier_params = list(model_f1.parameters()) + list(model_f2.parameters())
optimizer_g = get_optimizer(generator_params, **optimizer_settings)
optimizer_f = get_optimizer(classifier_params, **optimizer_settings)
Example #4
0
    # NOTE(review): fragment — the enclosing header (presumably a resume branch
    # such as `if args.resume:`) lies before this view. `indir` is unused here.
    indir, infn = os.path.split(args.resume)

    # Recover the save name from the checkpoint file name ("<savename>-...").
    old_savename = args.savename
    args.savename = infn.split("-")[0]
    print("savename is %s (original savename %s was overwritten)" %
          (args.savename, old_savename))

    checkpoint = torch.load(args.resume)
    start_epoch = checkpoint["epoch"]
    # ---------- Replace Args!!! ----------- #
    # Training resumes under the arguments stored in the checkpoint; the
    # command-line `args` parsed earlier is discarded from this point on.
    args = checkpoint['args']
    # -------------------------------------- #
    model_g, model_f1, model_f2 = get_models(
        net_name=args.net,
        res=args.res,
        input_ch=args.input_ch,
        n_class=args.n_class,
        method=args.method,
        is_data_parallel=args.is_data_parallel)
    # Separate optimizers for the generator and the two classifier heads,
    # built with the same hyper-parameters.
    optimizer_g = get_optimizer(model_g.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                opt=args.opt,
                                weight_decay=args.weight_decay)
    optimizer_f = get_optimizer(list(model_f1.parameters()) +
                                list(model_f2.parameters()),
                                lr=args.lr,
                                opt=args.opt,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
])  # NOTE(review): closes an img_transform Compose whose opening lies before this view.
label_transform = Compose([Scale(train_img_shape, Image.BILINEAR), ToTensor()])

# Target-domain split, evaluated one sample at a time (batch_size=1).
tgt_dataset = get_dataset(dataset_name=args.tgt_dataset,
                          split=args.split,
                          img_transform=img_transform,
                          label_transform=label_transform,
                          test=True,
                          input_ch=train_args.input_ch)
target_loader = data.DataLoader(tgt_dataset, batch_size=1, pin_memory=True)

# Older checkpoints lack `method`/`is_data_parallel` in their stored args;
# the AttributeError fallback rebuilds with historical defaults.
try:
    G, F1, F2 = get_models(net_name=train_args.net,
                           res=train_args.res,
                           input_ch=train_args.input_ch,
                           n_class=train_args.n_class,
                           method=train_args.method,
                           is_data_parallel=train_args.is_data_parallel,
                           use_ae=args.use_ae)
except AttributeError:
    G, F1, F2 = get_models(net_name=train_args.net,
                           res=train_args.res,
                           input_ch=train_args.input_ch,
                           n_class=train_args.n_class,
                           method="MCD",
                           is_data_parallel=False)

G.load_state_dict(checkpoint['g_state_dict'])
F1.load_state_dict(checkpoint['f1_state_dict'])

# NOTE(review): the body of this `if` is truncated — the rest of the snippet
# lies beyond this view.
if args.use_f2:
Example #6
0
# Parse CLI arguments and normalize derived fields (extra params, image shape).
args = parser.parse_args()
args = add_additional_params_to_args(args)
args = fix_img_shape_args(args)

FORMAT = '[%(filename)s:%(lineno)s - %(funcName)s() %(levelname)s]: %(message)s'
logging.basicConfig(level=args.logging, format=FORMAT)

if not os.path.exists(args.trained_checkpoint):
    raise OSError("%s does not exist!" % args.trained_checkpoint)

# Network hyper-parameters come from the arguments stored at training time.
checkpoint = torch.load(args.trained_checkpoint)
train_args = checkpoint['args']

model_g, model_f1, model_f2 = get_models(net_name=train_args.net,
                                         res=train_args.res,
                                         input_ch=train_args.input_ch,
                                         n_class=train_args.n_class,
                                         yaw_loss=train_args.yaw_loss)
model_g.load_state_dict(checkpoint['g_state_dict'])
model_f1.load_state_dict(checkpoint['f1_state_dict'])

# Inference mode, then move to GPU if one is available (model_f2 stays unused).
for network in (model_g, model_f1):
    network.eval()
if torch.cuda.is_available():
    for network in (model_g, model_f1):
        network.cuda()

print("----- train args ------")
pprint(train_args.__dict__, indent=4)
print("-" * 50)
args.train_img_shape = train_args.train_img_shape
print("=> loaded checkpoint '{}'".format(args.trained_checkpoint))
Example #7
0
                              # NOTE(review): fragment — the opening of this
                              # get_dataset(...) call (and the src_dataset
                              # setup) lies before this view.
                              img_transform=img_transform,
                              label_transform=label_transform,
                              test=False,
                              input_ch=args.input_ch,
                              keys_dict={'image': 'T_image'})

# Source and target samples are drawn together from one shuffled loader.
concat_dataset = ConcatDataset([src_dataset, tgt_dataset])
train_loader = torch.utils.data.DataLoader(concat_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           pin_memory=True)

model_g, model_f1, model_f2 = get_models(
    net_name=args.net,
    res=args.res,
    input_ch=args.input_ch,
    n_class=args.n_class,
    is_data_parallel=args.is_data_parallel,
    yaw_loss=args.yaw_loss)

# Separate optimizers for the generator and the two classifier heads, built
# with identical hyper-parameters.
optimizer_g = get_optimizer(model_g.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            opt=args.opt,
                            weight_decay=args.weight_decay)
optimizer_f = get_optimizer(list(model_f1.parameters()) +
                            list(model_f2.parameters()),
                            lr=args.lr,
                            momentum=args.momentum,
                            opt=args.opt,
                            weight_decay=args.weight_decay)
Example #8
0
                          # NOTE(review): fragment — the opening of this
                          # get_dataset(...) call lies before this view.
                          test=True,
                          input_ch=train_args.input_ch,
                          # Pass-through key mapping: the loader also yields the
                          # un-transformed image and its source URL per sample.
                          keys_dict={
                              'image': 'image',
                              'image_original': 'image_original',
                              'url': 'url'
                          })
# Deterministic order (shuffle=False), batches of 10.
target_loader = data.DataLoader(tgt_dataset,
                                batch_size=10,
                                pin_memory=True,
                                shuffle=False)

# Rebuild the networks with the hyper-parameters stored at training time.
G, F1, F2 = get_models(net_name=train_args.net,
                       res=train_args.res,
                       input_ch=train_args.input_ch,
                       n_class=train_args.n_class,
                       method=train_args.method,
                       is_data_parallel=train_args.is_data_parallel,
                       yaw_loss=train_args.yaw_loss)

G.load_state_dict(checkpoint['g_state_dict'])
F1.load_state_dict(checkpoint['f1_state_dict'])

# The second classifier's weights are restored only on request.
if args.use_f2:
    F2.load_state_dict(checkpoint['f2_state_dict'])
print("=> loaded checkpoint '{}' (epoch {})".format(args.trained_checkpoint,
                                                    checkpoint['epoch']))

# Inference mode (disables dropout, freezes batch-norm statistics).
G.eval()
F1.eval()
F2.eval()