示例#1
0
    def build_optimizer(self):
        """
        Create the Adam optimizer and the multi-GPU warm-up learning rates.

        Depending on ``self.opt.train_only_encoder``, either only the encoder
        parameters or all parameters of ``self.network.module`` are optimized,
        in both cases with learning rate ``self.opt.lrate``.

        When ``self.opt.reload_optimizer_path`` is not empty, the optimizer
        state is restored from that file on a best-effort basis: a failure is
        reported and the freshly created optimizer is kept.

        Sets ``self.optimizer`` and ``self.next_learning_rates``.
        """
        model = self.network.module
        if self.opt.train_only_encoder:
            # To train a resnet image encoder with a pre-trained atlasnet decoder.
            yellow_print("only train the Encoder")
            trainable_parameters = model.encoder.parameters()
        else:
            trainable_parameters = model.parameters()
        self.optimizer = optim.Adam(trainable_parameters, lr=self.opt.lrate)

        optimizer_path = self.opt.reload_optimizer_path
        if optimizer_path != "":
            try:
                self.optimizer.load_state_dict(
                    torch.load(optimizer_path, map_location='cuda:0'))
            except Exception as error:
                # Reloading stays best-effort, but catching `Exception` rather
                # than using a bare `except:` lets KeyboardInterrupt and
                # SystemExit propagate, and the cause is reported.
                yellow_print(
                    f"Failed to reload optimizer {optimizer_path}: {error}")

        # Set policy for warm-up if you use multiple GPUs
        self.next_learning_rates = []
        gpu_count = len(self.opt.multi_gpu)
        if gpu_count > 1:
            # Five evenly spaced values between lrate and lrate * gpu_count,
            # reversed so the list starts at lrate * gpu_count.
            self.next_learning_rates = np.linspace(
                self.opt.lrate, self.opt.lrate * gpu_count, 5).tolist()
            self.next_learning_rates.reverse()
示例#2
0
    def reload_network(self):
        """
        Restore network weights according to the reload options.

        A non-empty ``self.opt.reload_model_path`` takes priority: the state
        dict stored there is loaded into ``self.network.module``. Otherwise, a
        non-empty ``self.opt.reload_decoder_path`` is loaded into a temporary
        ``EncoderDecoder`` built from a copy of the options with ``SVR`` set to
        False, and only its decoder replaces ``self.network.module.decoder``.
        When both paths are empty a message is printed and nothing is loaded.
        :return:
        """
        if self.opt.reload_model_path != "":
            yellow_print(
                f"Network weights loaded from  {self.opt.reload_model_path}!")
            restore_model = self.network.module.load_state_dict
            restore_model(
                torch.load(self.opt.reload_model_path, map_location='cuda:0'))
            return

        if self.opt.reload_decoder_path == "":
            yellow_print("No network weights to reload!")
            return

        # Work on a copy so the trainer's own options are left untouched.
        decoder_opt = deepcopy(self.opt)
        decoder_opt.SVR = False
        donor = nn.DataParallel(EncoderDecoder(decoder_opt),
                                device_ids=decoder_opt.multi_gpu)
        restore_donor = donor.module.load_state_dict
        restore_donor(
            torch.load(decoder_opt.reload_decoder_path,
                       map_location='cuda:0'))
        # Only the decoder of the temporary network is kept.
        self.network.module.decoder = donor.module.decoder
        yellow_print(
            f"Network Decoder weights loaded from  {self.opt.reload_decoder_path}!"
        )
示例#3
0
    def reload_network(self):
        """
        Reload entire model or only decoder (atlasnet) depending on the options.

        If ``self.opt.reload_model_path`` is not empty, the state dict stored
        there is loaded into ``self.network.module``. A leading ``module.`` on
        a key is removed first; keys without that prefix are kept unchanged,
        so checkpoints saved with or without the prefix both load.

        Otherwise, if ``self.opt.reload_decoder_path`` is not empty, the
        checkpoint is loaded into a temporary ``EncoderDecoder`` built from a
        copy of the options with ``SVR`` set to False, and only its decoder
        replaces ``self.network.module.decoder``.

        When both paths are empty a message is printed and nothing is loaded.
        :return:
        """
        if self.opt.reload_model_path != "":
            yellow_print(
                f"Network weights loaded from  {self.opt.reload_model_path}!")

            state_dict = torch.load(self.opt.reload_model_path,
                                    map_location='cuda:0')
            # Create a new OrderedDict whose keys do not start with `module.`.
            # The prefix is stripped only when present: slicing every key
            # unconditionally would corrupt the names of a checkpoint that
            # was saved without the prefix.
            from collections import OrderedDict
            prefix = "module."
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items())

            self.network.module.load_state_dict(new_state_dict)

        elif self.opt.reload_decoder_path != "":
            # Work on a copy so the trainer's own options are left untouched.
            opt = deepcopy(self.opt)
            opt.SVR = False
            network = EncoderDecoder(opt)
            network = nn.DataParallel(network, device_ids=opt.multi_gpu)
            network.module.load_state_dict(
                torch.load(opt.reload_decoder_path, map_location='cuda:0'))
            # Only the decoder of the temporary network is kept.
            self.network.module.decoder = network.module.decoder
            yellow_print(
                f"Network Decoder weights loaded from  {self.opt.reload_decoder_path}!"
            )

        else:
            yellow_print("No network weights to reload!")
示例#4
0
# Training entry-point script: prepares the trainer, optionally runs a demo or
# a single evaluation and exits, otherwise runs the train/test loop.
# `trainer` and `opt` are created earlier in the file (not shown here).
trainer.build_dataset()
trainer.build_network()
trainer.build_optimizer()
trainer.build_losses()
trainer.start_train_time = time.time()

# Demo mode: run the demo on the given input without gradient tracking,
# then exit with status 0.
if opt.demo:
    with torch.no_grad():
        trainer.demo(opt.demo_input_path)
    sys.exit(0)

# Single-evaluation mode: run one test epoch without gradient tracking,
# then exit with status 0.
if opt.run_single_eval:
    with torch.no_grad():
        trainer.test_epoch()
    sys.exit(0)

# Main loop, from trainer.epoch up to (excluding) opt.nepoch. Each iteration
# trains, tests without gradient tracking, dumps stats, advances the epoch
# counter and saves the network.
for epoch in range(trainer.epoch, opt.nepoch):
    trainer.train_epoch()
    with torch.no_grad():
        trainer.test_epoch()
    trainer.dump_stats()
    trainer.increment_epoch()
    trainer.save_network()

# Final summary: where to find the reports and how long the run took.
yellow_print(f"Visdom url http://localhost:{trainer.opt.visdom_port}/")
yellow_print(
    f"Netvision report url http://localhost:{trainer.opt.http_port}/{trainer.opt.dir_name}/index.html"
)
# NOTE(review): this reads trainer.start_time, whereas the script above sets
# trainer.start_train_time — confirm which attribute is intended here.
yellow_print(
    f"Training time {(time.time() - trainer.start_time)//60} minutes.")