Example #1
def main_NeuMF(hyper_params, gpu_id=None):
    from pytorch_models.NeuMF import GMF, MLP, NeuMF
    from data import load_data
    from eval import evaluate, eval_ranking
    from utils import load_user_item_counts, is_cuda_available
    from utils import xavier_init, log_end_epoch
    from loss import MSELoss
    import torch
    import time

    user_count, item_count = load_user_item_counts(hyper_params)
    train_reader, test_reader, val_reader, hyper_params = load_data(
        hyper_params)
    start_time = time.time()

    initial_path = hyper_params['model_path']

    # Pre-Training the GMF Model
    hyper_params['model_path'] = initial_path + "_gmf"
    gmf_model = GMF(hyper_params)
    if is_cuda_available: gmf_model = gmf_model.cuda()
    xavier_init(gmf_model)
    gmf_model = train_complete(hyper_params, GMF, train_reader, val_reader,
                               user_count, item_count, gmf_model)

    # Pre-Training the MLP Model
    hyper_params['model_path'] = initial_path + "_mlp"
    mlp_model = MLP(hyper_params)
    if is_cuda_available: mlp_model = mlp_model.cuda()
    xavier_init(mlp_model)
    mlp_model = train_complete(hyper_params, MLP, train_reader, val_reader,
                               user_count, item_count, mlp_model)

    # Training the final NeuMF Model
    hyper_params['model_path'] = initial_path
    model = NeuMF(hyper_params)
    if is_cuda_available: model = model.cuda()
    model.init(gmf_model, mlp_model)
    model = train_complete(hyper_params, NeuMF, train_reader, val_reader,
                           user_count, item_count, model)

    # Evaluating the final model for MSE on test-set
    criterion = MSELoss(hyper_params)
    metrics, user_count_mse_map, item_count_mse_map = evaluate(model,
                                                               criterion,
                                                               test_reader,
                                                               hyper_params,
                                                               user_count,
                                                               item_count,
                                                               review=False)

    # Evaluating the final model for HR@1 on test-set
    metrics.update(eval_ranking(model, test_reader, hyper_params,
                                review=False))

    log_end_epoch(hyper_params,
                  metrics,
                  'final', (time.time() - start_time),
                  metrics_on='(TEST)')

    return metrics, user_count_mse_map, item_count_mse_map
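# For reference, a minimal hyper-parameter dict covering the keys visible in this
# listing ('model_path' here; 'log_file', 'model_type', 'dataset', 'lr',
# 'weight_decay', 'epochs' inside train_complete, Example #8). Every value is a
# made-up placeholder; the project's real config lives in its hyper_params.py.
hyper_params = {
    'model_path': 'saved_models/neumf',
    'log_file': 'logs/neumf.txt',
    'model_type': 'NeuMF',
    'dataset': 'some_dataset',
    'lr': 0.001,
    'weight_decay': 1e-6,
    'epochs': 50,
}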
Example #2
def main(args):
    network_model = Model(pool=args.pool)
    if torch.cuda.is_available():
        network_model = network_model.cuda()

    if args.loss == 'L2Loss':
        criterion = MSELoss()
    elif 'CrossEntropy' in args.loss:
        criterion = MaxLoss(args.loss[13:])  # assumes names like 'CrossEntropy_<variant>'
    else:
        raise ValueError('unsupported loss: {}'.format(args.loss))

    if args.lp_norm == 'None':
        regularizer = None
    else:
        regularizer = LpNorm(args.lp_norm, args.lp_norm_factor)

    optimizer = torch.optim.Adam(network_model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.lr_milestone,
                                                     gamma=0.5)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_set = SvhnDataset(root=args.root, train=True, transform=transform)
    test_set = SvhnDataset(root=args.root, train=False, transform=transform)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    test_loader = DataLoader(test_set,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True)

    if args.evaluate:
        test(network_model, test_loader)
        return

    train(network_model, train_loader, test_loader, optimizer, scheduler,
          criterion, regularizer, args)
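# A minimal argument parser covering every attribute main() reads from args.
# The flag names mirror the attribute accesses above; the defaults are guesses.
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pool', default='max')
    parser.add_argument('--loss', default='L2Loss')  # or e.g. 'CrossEntropy_<variant>'
    parser.add_argument('--lp_norm', default='None')
    parser.add_argument('--lp_norm_factor', type=float, default=1e-4)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_milestone', type=int, nargs='+', default=[30, 60])
    parser.add_argument('--root', default='./data')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--evaluate', action='store_true')
    return parser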
Example #3
def build_loss_from_cfg(config):
    """Builds loss function with specific configuration.
    Args:
        config: the configuration.

    Returns:
        A nn.Module loss.
    """
    if config.NAME == 'cross_entropy':
        # return CrossEntropyLoss(ignore_index=config.IGNORE, reduction='mean')
        return RegularCE(ignore_label=config.IGNORE)
    elif config.NAME == 'ohem':
        return OhemCE(ignore_label=config.IGNORE,
                      threshold=config.THRESHOLD,
                      min_kept=config.MIN_KEPT)
    elif config.NAME == 'hard_pixel_mining':
        return DeepLabCE(ignore_label=config.IGNORE,
                         top_k_percent_pixels=config.TOP_K_PERCENT)
    elif config.NAME == 'mse':
        return MSELoss(reduction=config.REDUCTION)
    elif config.NAME == 'l1':
        return L1Loss(reduction=config.REDUCTION)
    else:
        raise ValueError('Unknown loss type: {}'.format(config.NAME))
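# A sketch of calling build_loss_from_cfg. It only assumes the config exposes
# attribute-style fields (the code reads config.NAME etc.), so SimpleNamespace
# works for illustration; the field values below are examples, not defaults.
from types import SimpleNamespace

ohem_cfg = SimpleNamespace(NAME='ohem', IGNORE=255, THRESHOLD=0.7, MIN_KEPT=100000)
ohem_criterion = build_loss_from_cfg(ohem_cfg)

mse_cfg = SimpleNamespace(NAME='mse', REDUCTION='mean')
mse_criterion = build_loss_from_cfg(mse_cfg)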
Example #5
class Combined_GAN(BaseAgent):
    def __init__(self, config):
        super().__init__(config)

        # define models (generator and discriminator)
        self.h2l_G = HighToLowGenerator()
        self.h2l_D = HighToLowDiscriminator()
        self.l2h_G = LowToHighGenerator()
        self.l2h_D = LowToHighDiscriminator()

        # define loss
        #self.loss = GANLoss()
        #self.loss = HingeEmbeddingLoss()
        self.criterion_GAN = torch.nn.BCEWithLogitsLoss()
        self.criterion_MSE = MSELoss()

        # define optimizers for both generator and discriminator
        self.l2h_optimG = torch.optim.Adam(self.l2h_G.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.l2h_optimD = torch.optim.Adam(self.l2h_D.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.h2l_optimG = torch.optim.Adam(self.h2l_G.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))
        self.h2l_optimD = torch.optim.Adam(self.h2l_D.parameters(),
                                           lr=self.config.learning_rate,
                                           betas=(self.config.beta1,
                                                  self.config.beta2))

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_valid_mean_iou = 0

        self.real_label = 1
        self.fake_label = -1

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.config.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.config.cuda

        # set the manual seed for torch
        self.manual_seed = random.randint(1, 10000)
        self.logger.info('seed:{}'.format(self.manual_seed))
        random.seed(self.manual_seed)

        self.test_file = self.config.output_path
        if not os.path.exists(self.test_file):
            os.makedirs(self.test_file)

        if self.cuda:
            self.device = torch.device("cuda")
            torch.cuda.set_device(self.config.gpu_device)
            torch.cuda.manual_seed_all(self.manual_seed)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
            print_cuda_statistics()
        else:
            self.device = torch.device("cpu")
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU***** ")

        self.l2h_G = self.l2h_G.to(self.device)
        self.l2h_D = self.l2h_D.to(self.device)
        self.h2l_G = self.h2l_G.to(self.device)
        self.h2l_D = self.h2l_D.to(self.device)
        self.criterion_GAN = self.criterion_GAN.to(self.device)
        self.criterion_MSE = self.criterion_MSE.to(self.device)

        # Summary Writer
        self.summary_writer_l2h = SummaryWriter(
            log_dir=self.config.summary_dir_l2h, comment='Low-To-High GAN')
        self.summary_writer_h2l = SummaryWriter(
            log_dir=self.config.summary_dir_h2l, comment='High-To-Low GAN')

    def load_checkpoint(self, file_name, model):
        if model == 'l2h':
            checkpoint_dir = self.config.checkpoint_l2h_dir
        elif model == 'h2l':
            checkpoint_dir = self.config.checkpoint_h2l_dir
        elif model == 'combined':
            checkpoint_dir = self.config.checkpoint_combined_dir
        else:
            raise ValueError('unknown model: {}'.format(model))

        filename = checkpoint_dir + file_name
        try:
            self.logger.info("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.manual_seed = checkpoint['manual_seed']

            if model == 'h2l':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.h2l_optimG.load_state_dict(checkpoint['h2l_G_optimizer'])
                self.h2l_D.load_state_dict(checkpoint['h2l_D_state_dict'])
                self.h2l_optimD.load_state_dict(checkpoint['h2l_D_optimizer'])

            elif model == 'l2h':
                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.l2h_optimG.load_state_dict(checkpoint['l2h_G_optimizer'])
                self.l2h_D.load_state_dict(checkpoint['l2h_D_state_dict'])
                self.l2h_optimD.load_state_dict(checkpoint['l2h_D_optimizer'])

            elif model == 'combined':
                self.h2l_G.load_state_dict(checkpoint['h2l_G_state_dict'])
                self.h2l_optimG.load_state_dict(checkpoint['h2l_G_optimizer'])
                self.h2l_D.load_state_dict(checkpoint['h2l_D_state_dict'])
                self.h2l_optimD.load_state_dict(checkpoint['h2l_D_optimizer'])

                self.l2h_G.load_state_dict(checkpoint['l2h_G_state_dict'])
                self.l2h_optimG.load_state_dict(checkpoint['l2h_G_optimizer'])
                self.l2h_D.load_state_dict(checkpoint['l2h_D_state_dict'])
                self.l2h_optimD.load_state_dict(checkpoint['l2h_D_optimizer'])

        except OSError:
            self.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    checkpoint_dir))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name, model, is_best=0):
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'manual_seed': self.manual_seed
        }
        if model == 'l2h':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            state['l2h_G_optimizer'] = self.l2h_optimG.state_dict()
            state['l2h_D_state_dict'] = self.l2h_D.state_dict()
            state['l2h_D_optimizer'] = self.l2h_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_l2h_dir

        elif model == 'h2l':
            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            state['h2l_G_optimizer'] = self.h2l_optimG.state_dict()
            state['h2l_D_state_dict'] = self.h2l_D.state_dict()
            state['h2l_D_optimizer'] = self.h2l_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_h2l_dir

        elif model == 'combined':
            state['l2h_G_state_dict'] = self.l2h_G.state_dict()
            state['l2h_G_optimizer'] = self.l2h_optimG.state_dict()
            state['l2h_D_state_dict'] = self.l2h_D.state_dict()
            state['l2h_D_optimizer'] = self.l2h_optimD.state_dict()

            state['h2l_G_state_dict'] = self.h2l_G.state_dict()
            state['h2l_G_optimizer'] = self.h2l_optimG.state_dict()
            state['h2l_D_state_dict'] = self.h2l_D.state_dict()
            state['h2l_D_optimizer'] = self.h2l_optimD.state_dict()

            checkpoint_dir = self.config.checkpoint_combined_dir
        else:
            raise ValueError('unknown model: {}'.format(model))
        # Save the state
        torch.save(state, checkpoint_dir + file_name)
        # If it is the best, copy it to another file '_best.pth.tar'
        if is_best:
            shutil.copyfile(checkpoint_dir + file_name,
                            checkpoint_dir + '_best.pth.tar')

    def run(self):
        """
        This function will the operator
        :return:
        """
        try:
            self.train()
        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        # Load from the latest checkpoint; if none is found, start from scratch.
        self.load_checkpoint(self.config.checkpoint_file_h2l, 'h2l')
        if self.current_epoch <= 200:
            self.load_checkpoint(self.config.checkpoint_file_l2h, 'l2h')
        elif self.current_epoch > 200:
            self.load_checkpoint(self.config.checkpoint_file_combined,
                                 'combined')
        if self.current_epoch != 0 and self.current_epoch <= 200:
            self.logger.info(
                "Checkpoint loaded successfully from '{}' and '{}' at (epoch {}) at (iteration {})\n"
                .format(self.config.checkpoint_l2h_dir,
                        self.config.checkpoint_h2l_dir, self.current_epoch,
                        self.current_iteration))

        for epoch in range(self.current_epoch, self.config.max_epoch):
            self.current_epoch = epoch
            if epoch <= 200:
                self.train_one_epoch_h2l()
                self.train_one_epoch_l2h()
                self.save_checkpoint(self.config.checkpoint_file_l2h, 'l2h')
                self.save_checkpoint(self.config.checkpoint_file_h2l, 'h2l')
            else:
                self.train_one_epoch_combined()
                self.save_checkpoint(self.config.checkpoint_file_combined,
                                     'combined')

    def to_var(self, data):
        real_cpu = data
        batchsize = real_cpu.size(0)
        inp = Variable(real_cpu.to(self.device))  # use the configured device rather than assuming CUDA
        return inp, batchsize

    def train_one_epoch_h2l(self):
        train_loader = get_loader(self.config.HighToLow_hr_datapath,
                                  self.config.HighToLow_lr_datapath,
                                  self.config.batch_size)

        self.h2l_G.train()
        self.h2l_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['lr']
            data_high = data_dict['hr']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##################
            #  Train Generator
            ##################

            self.h2l_optimG.zero_grad()

            # Generate a low-resolution image from the high-resolution input
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            gen_hr = self.h2l_G(data_input_high, noise)

            # Measure pixel-wise loss against ground truth
            loss_pixel = self.criterion_MSE(gen_hr, data_input_low)

            # Extract validity predictions from discriminator
            pred_real = self.h2l_D(data_input_high).detach()
            pred_fake = self.h2l_D(gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            loss_G_GAN = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            loss_G = (self.config.beta * loss_G_GAN) + (self.config.alpha *
                                                        loss_pixel)

            loss_G.backward(retain_graph=True)
            self.h2l_optimG.step()

            ######################
            #  Train Discriminator
            ######################

            self.h2l_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            pred_real = self.h2l_D(data_input_high)
            y.fill_(self.real_label)
            loss_D_real = self.criterion_GAN(
                pred_real - pred_fake.mean(0, keepdim=True), y)
            loss_D_real.backward(retain_graph=True)

            pred_fake = self.h2l_D(gen_hr.detach())
            y.fill_(self.fake_label)
            loss_D_fake = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)
            loss_D_fake.backward()
            # Total loss
            loss_D = (loss_D_real + loss_D_fake) / 2

            #loss_D.backward()
            self.h2l_optimD.step()

            self.current_iteration += 1

            self.summary_writer_h2l.add_scalar("epoch/Generator_loss",
                                               loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_real",
                                               loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_fake",
                                               loss_D_fake.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_h2l.jpg')
            vutils.save_image(gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "High-To-Low GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    loss_D.item(),
                    loss_D_real.item(),
                    loss_D_fake.item(),
                    loss_G.item(),
                    loss_G_GAN.item(),
                    loss_pixel.item(),
                ))

    def train_one_epoch_l2h(self):
        train_loader = get_loader(self.config.LowToHigh_datapath, None,
                                  self.config.batch_size)

        self.l2h_G.train()
        self.l2h_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['img16']
            data_high = data_dict['img64']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##################
            #  Train Generator
            ##################

            self.l2h_optimG.zero_grad()

            # Generate a high resolution image from low resolution input
            gen_hr = self.l2h_G(data_input_low)

            # Measure pixel-wise loss against ground truth
            loss_pixel = self.criterion_MSE(gen_hr, data_input_high)

            # Extract validity predictions from discriminator
            pred_real = self.l2h_D(data_input_high).detach()
            pred_fake = self.l2h_D(gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            loss_G_GAN = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            loss_G = (self.config.beta * loss_G_GAN) + (self.config.alpha *
                                                        loss_pixel)

            loss_G.backward(retain_graph=True)
            self.l2h_optimG.step()

            ######################
            #  Train Discriminator
            ######################

            self.l2h_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            pred_real = self.l2h_D(data_input_high)
            y.fill_(self.real_label)
            loss_D_real = self.criterion_GAN(
                pred_real - pred_fake.mean(0, keepdim=True), y)
            loss_D_real.backward(retain_graph=True)

            pred_fake = self.l2h_D(gen_hr.detach())
            y.fill_(self.fake_label)
            loss_D_fake = self.criterion_GAN(
                pred_fake - pred_real.mean(0, keepdim=True), y)
            loss_D_fake.backward()
            # Total loss
            loss_D = (loss_D_real + loss_D_fake) / 2

            self.l2h_optimD.step()

            self.current_iteration += 1

            self.summary_writer_l2h.add_scalar("epoch/Generator_loss",
                                               loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_real",
                                               loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_fake",
                                               loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss",
                                               loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_l2h.jpg')
            vutils.save_image(gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Low-To-High GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    loss_D.item(),
                    loss_D_real.item(),
                    loss_D_fake.item(),
                    loss_G.item(),
                    loss_G_GAN.item(),
                    loss_pixel.item(),
                ))

    def train_one_epoch_combined(self):
        train_loader = get_loader(self.config.HighToLow_hr_datapath,
                                  self.config.HighToLow_lr_datapath,
                                  self.config.batch_size)

        self.h2l_G.train()
        self.h2l_D.train()
        self.l2h_G.train()
        self.l2h_D.train()

        for curr_it, data_dict in enumerate(train_loader):
            data_low = data_dict['lr']
            data_high = data_dict['hr']
            data_input_low, batchsize = self.to_var(data_low)
            data_input_high, _ = self.to_var(data_high)

            y = torch.randn(data_low.size(0), )
            y, _ = self.to_var(y)

            ##############################
            #  Train High-To-Low Generator
            ##############################

            self.h2l_optimG.zero_grad()

            # Generate a low-resolution image from the high-resolution input
            noise = torch.randn(data_high.size(0), 1)
            noise, _ = self.to_var(noise)
            h2l_gen_hr = self.h2l_G(data_input_high, noise)

            # Measure pixel-wise loss against ground truth
            h2l_loss_pixel = self.criterion_MSE(h2l_gen_hr, data_input_low)

            # Extract validity predictions from discriminator
            h2l_pred_real = self.h2l_D(data_input_high).detach()
            h2l_pred_fake = self.h2l_D(h2l_gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            h2l_loss_G_GAN = self.criterion_GAN(
                h2l_pred_fake - h2l_pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            h2l_loss_G = (self.config.beta * h2l_loss_G_GAN) + (
                self.config.alpha * h2l_loss_pixel)

            h2l_loss_G.backward(retain_graph=True)
            self.h2l_optimG.step()

            ##################################
            #  Train High-To-Low Discriminator
            ##################################

            self.h2l_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            h2l_pred_real = self.h2l_D(data_input_high)
            y.fill_(self.real_label)
            h2l_loss_D_real = self.criterion_GAN(
                h2l_pred_real - h2l_pred_fake.mean(0, keepdim=True), y)
            h2l_loss_D_real.backward(retain_graph=True)

            h2l_pred_fake = self.h2l_D(h2l_gen_hr.detach())
            y.fill_(self.fake_label)
            h2l_loss_D_fake = self.criterion_GAN(
                h2l_pred_fake - h2l_pred_real.mean(0, keepdim=True), y)
            h2l_loss_D_fake.backward()
            # Total loss
            h2l_loss_D = (h2l_loss_D_real + h2l_loss_D_fake) / 2

            self.h2l_optimD.step()

            self.current_iteration += 1

            self.summary_writer_h2l.add_scalar("epoch/Generator_loss",
                                               h2l_loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_real",
                                               h2l_loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss_fake",
                                               h2l_loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_h2l.add_scalar("epoch/Discriminator_loss",
                                               h2l_loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_combined_intermediate.jpg')
            vutils.save_image(h2l_gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Combined model: High-To-Low GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    h2l_loss_D.item(),
                    h2l_loss_D_real.item(),
                    h2l_loss_D_fake.item(),
                    h2l_loss_G.item(),
                    h2l_loss_G_GAN.item(),
                    h2l_loss_pixel.item(),
                ))

            data_input_low = h2l_gen_hr

            y = torch.randn(data_input_low.size(0), )
            y, _ = self.to_var(y)

            ##############################
            #  Train Low-To-High Generator
            ##############################

            self.l2h_optimG.zero_grad()

            # Generate a high resolution image from low resolution input
            l2h_gen_hr = self.l2h_G(data_input_low)

            # Measure pixel-wise loss against ground truth
            l2h_loss_pixel = self.criterion_MSE(l2h_gen_hr, data_input_high)

            # Extract validity predictions from discriminator
            l2h_pred_real = self.l2h_D(data_input_high).detach()
            l2h_pred_fake = self.l2h_D(l2h_gen_hr)

            # Adversarial loss (relativistic average GAN)
            y.fill_(self.real_label)
            l2h_loss_G_GAN = self.criterion_GAN(
                l2h_pred_fake - l2h_pred_real.mean(0, keepdim=True), y)

            # Total generator loss
            l2h_loss_G = (self.config.beta * l2h_loss_G_GAN) + (
                self.config.alpha * l2h_loss_pixel)

            l2h_loss_G.backward(retain_graph=True)
            self.l2h_optimG.step()

            ##################################
            #  Train Low-To-High Discriminator
            ##################################

            self.l2h_optimD.zero_grad()

            # Adversarial loss for real and fake images (relativistic average GAN)
            l2h_pred_real = self.l2h_D(data_input_high)
            y.fill_(self.real_label)
            l2h_loss_D_real = self.criterion_GAN(
                l2h_pred_real - l2h_pred_fake.mean(0, keepdim=True), y)
            l2h_loss_D_real.backward(retain_graph=True)

            l2h_pred_fake = self.l2h_D(l2h_gen_hr.detach())
            y.fill_(self.fake_label)
            l2h_loss_D_fake = self.criterion_GAN(
                l2h_pred_fake - l2h_pred_real.mean(0, keepdim=True), y)
            l2h_loss_D_fake.backward()
            # Total loss
            l2h_loss_D = (l2h_loss_D_real + l2h_loss_D_fake) / 2

            self.l2h_optimD.step()

            self.current_iteration += 1

            self.summary_writer_l2h.add_scalar("epoch/Generator_loss",
                                               l2h_loss_G.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_real",
                                               l2h_loss_D_real.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss_fake",
                                               l2h_loss_D_fake.item(),
                                               self.current_iteration)
            self.summary_writer_l2h.add_scalar("epoch/Discriminator_loss",
                                               l2h_loss_D.item(),
                                               self.current_iteration)

            path = os.path.join(
                self.test_file, 'batch' + str(curr_it) + '_epoch' +
                str(self.current_epoch) + '_combined_final.jpg')
            vutils.save_image(l2h_gen_hr.data, path, normalize=True)

            # --------------
            #  Log Progress
            # --------------

            self.logger.info(
                "Combined model: Low-To-High GAN: [Epoch %d/%d] [Batch %d/%d] [D loss: %f, real: %f, fake: %f] [G loss: %f, adv: %f, pixel: %f]"
                % (
                    self.current_epoch + 1,
                    self.config.max_epoch,
                    curr_it + 1,
                    len(train_loader),
                    l2h_loss_D.item(),
                    l2h_loss_D_real.item(),
                    l2h_loss_D_fake.item(),
                    l2h_loss_G.item(),
                    l2h_loss_G_GAN.item(),
                    l2h_loss_pixel.item(),
                ))

    def validate(self):
        pass

    def finalize(self):
        """
        Finalize the operator: save a checkpoint and flush the summary writers.
        :return:
        """
        self.logger.info(
            "Please wait while finalizing the operation.. Thank you")
        self.save_checkpoint(self.config.checkpoint_file_combined, 'combined')
        self.summary_writer_l2h.export_scalars_to_json(
            "{}all_scalars.json".format(self.config.summary_dir_l2h))
        self.summary_writer_h2l.export_scalars_to_json(
            "{}all_scalars.json".format(self.config.summary_dir_h2l))
        self.summary_writer_l2h.close()
        self.summary_writer_h2l.close()
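# Aside: a self-contained sketch of the relativistic average GAN loss used by
# the agent above, on toy logits. Note BCEWithLogitsLoss conventionally expects
# targets in [0, 1], so this sketch uses 1/0 labels rather than the 1/-1 labels
# the agent fills in.
import torch

criterion = torch.nn.BCEWithLogitsLoss()
pred_real = torch.randn(8, 1)  # discriminator logits for real images
pred_fake = torch.randn(8, 1)  # discriminator logits for generated images
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

# Generator: fakes should score above the average real logit
loss_G = criterion(pred_fake - pred_real.mean(0, keepdim=True), ones)

# Discriminator: reals above the average fake, fakes below the average real
loss_D_real = criterion(pred_real - pred_fake.mean(0, keepdim=True), ones)
loss_D_fake = criterion(pred_fake - pred_real.mean(0, keepdim=True), zeros)
loss_D = (loss_D_real + loss_D_fake) / 2
print(loss_G.item(), loss_D.item())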
Example #6
def main(hyper_params=None, pretrain_full_info=False, train_regressor=False):
    # If custom hyper_params are not passed, load from hyper_params.py
    if hyper_params is None: from hyper_params import hyper_params
    else: print("Using passed hyper-parameters..")

    # Initialize a tensorboard writer
    global writer
    path = hyper_params['tensorboard_path']
    writer = SummaryWriter(path, flush_secs=20)

    # Loading data
    if pretrain_full_info == True:
        train_reader, test_reader = load_data_full_info(hyper_params)
    else:
        train_reader, test_reader, val_reader = load_data(
            hyper_params, train_regressor=train_regressor)
        hyper_params['all_ks'] = get_all(train_reader)  # For MinSup evaluation

    file_write(hyper_params,
               "\n\nSimulation run on: " + str(dt.datetime.now()) + "\n\n")
    file_write(hyper_params, "Data reading complete!")
    file_write(hyper_params,
               "Number of train batches: {:4d}".format(len(train_reader)))
    if pretrain_full_info == False:
        file_write(hyper_params,
                   "Number of val batches: {:4d}".format(len(val_reader)))
    file_write(hyper_params,
               "Number of test batches: {:4d}".format(len(test_reader)))
    if 'all_ks' in hyper_params:
        file_write(
            hyper_params,
            "MinSup estimated k: " + str(hyper_params['all_ks']) + "\n\n")

    # Creating model
    if train_regressor: model = RegressionModelCifar(hyper_params)
    else: model = ModelCifar(hyper_params)
    if is_cuda_available: model.cuda()

    # Loss function
    if pretrain_full_info: criterion = nn.CrossEntropyLoss()
    elif train_regressor: criterion = MSELoss(hyper_params)
    else: criterion = CustomLoss(hyper_params)

    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=hyper_params['lr'],
                                momentum=0.9,
                                weight_decay=hyper_params['weight_decay'])

    file_write(hyper_params, str(model))
    if pretrain_full_info == True:
        file_write(hyper_params, "Pre-training model on full information..")
    file_write(hyper_params, "\nModel Built!\nStarting Training...\n")

    best_metrics_val = None
    validate_on = hyper_params[
        'validate_using']  # Estimator used to choose the best model (validation)
    if pretrain_full_info == True:
        validate_on = "Accuracy"  # Since full-information

    try:
        for epoch in range(1, hyper_params['epochs'] + 1):
            epoch_start_time = time.time()
            metrics_train, metrics_val = None, None

            if pretrain_full_info == True:
                metrics_train = train_full_info(model, criterion, optimizer,
                                                train_reader, hyper_params,
                                                epoch)
                # Note that metrics_train differs from the model's true performance,
                # because accuracy is measured WHILE the model is being trained.
                # Re-evaluating with the parameters frozen would most likely give
                # a different (better) accuracy.

                # Don't validate for logging policy. Just store the model at every epoch.
                torch.save(model.state_dict(), hyper_params['model_file'])
            else:
                metrics_train = train(model, criterion, optimizer,
                                      train_reader, hyper_params, epoch)
                # Calculating the metrics on the validation set
                metrics_val = evaluate(model,
                                       criterion,
                                       val_reader,
                                       hyper_params,
                                       eval_estimators=True,
                                       test_set=False)

                # Validate
                if best_metrics_val is None: best_metrics_val = metrics_val
                elif metrics_val[validate_on] >= best_metrics_val[validate_on]:
                    best_metrics_val = metrics_val

                # Save model if current is best epoch
                if metrics_val[validate_on] == best_metrics_val[validate_on]:
                    torch.save(model.state_dict(), hyper_params['model_file'])

            metrics_train = None  # Don't print train metrics, since already printing in tqdm bar
            log_end_epoch(hyper_params, epoch, epoch_start_time, writer,
                          metrics_train, metrics_val)

    except KeyboardInterrupt:
        print('Exiting from training early')

    # Evaluate best saved model
    model = ModelCifar(hyper_params)
    if is_cuda_available: model.cuda()
    model.load_state_dict(torch.load(hyper_params['model_file']))
    model.eval()

    metrics_train = None
    metrics_test = evaluate(model,
                            criterion,
                            test_reader,
                            hyper_params,
                            eval_estimators=False,
                            test_set=True)

    file_write(hyper_params, "Final model performance on test-set:")
    log_end_epoch(hyper_params,
                  hyper_params['epochs'] + 1,
                  time.time(),
                  writer,
                  metrics_train,
                  metrics_test,
                  test=True)

    writer.close()

    return metrics_test
Example #7
def main_pytorch(hyper_params, gpu_id=None):
    from data import load_data
    from eval import evaluate, eval_ranking
    from utils import load_obj, is_cuda_available
    from utils import load_user_item_counts, xavier_init, log_end_epoch
    from loss import MSELoss

    if hyper_params['model_type'] in ['deepconn', 'deepconn++']:
        from pytorch_models.DeepCoNN import DeepCoNN as Model
    elif hyper_params['model_type'] in ['transnet', 'transnet++']:
        from pytorch_models.TransNet import TransNet as Model
    elif hyper_params['model_type'] in ['NARRE']:
        from pytorch_models.NARRE_modify import NARRE as Model
    elif hyper_params['model_type'] in ['bias_only', 'MF', 'MF_dot']:
        from pytorch_models.MF import MF as Model

    import torch
    import time

    # Load the data readers
    user_count, item_count = load_user_item_counts(hyper_params)
    if hyper_params['model_type'] not in [
            'bias_only', 'MF', 'MF_dot', 'NeuMF'
    ]:
        review_based_model = True
        try:
            from data_fast import load_data_fast
            train_reader, test_reader, val_reader, hyper_params = load_data_fast(
                hyper_params)
            print(
                "Loaded preprocessed epoch files. Should be faster training..."
            )
        except Exception as e:
            print("Tried loading preprocessed epoch files, but failed.")
            print(
                "Please consider running `prep_all_data.sh` to make quick data for DeepCoNN/TransNet/NARRE."
            )
            print("This will save large amounts of run time.")
            print("Loading standard (slower) data..")
            train_reader, test_reader, val_reader, hyper_params = load_data(
                hyper_params)
    else:
        review_based_model = False
        train_reader, test_reader, val_reader, hyper_params = load_data(
            hyper_params)

    # Initialize the model
    model = Model(hyper_params)
    if is_cuda_available: model = model.cuda()
    xavier_init(model)

    # Train the model
    start_time = time.time()
    model = train_complete(hyper_params,
                           Model,
                           train_reader,
                           val_reader,
                           user_count,
                           item_count,
                           model,
                           review=review_based_model)

    # Calculating MSE on test-set
    print("Calculating MSE on test-set")
    criterion = MSELoss(hyper_params)
    metrics, user_count_mse_map, item_count_mse_map = evaluate(
        model,
        criterion,
        test_reader,
        hyper_params,
        user_count,
        item_count,
        review=review_based_model)
    print("Calculating HR@1 on test-set")
    # Calculating HR@1 on test-set
    _, test_reader2, _, _ = load_data(
        hyper_params)  # Needs default slow reader
    metrics.update(
        eval_ranking(model,
                     test_reader2,
                     hyper_params,
                     review=review_based_model))

    log_end_epoch(hyper_params,
                  metrics,
                  'final',
                  time.time() - start_time,
                  metrics_on='(TEST)')

    return metrics, user_count_mse_map, item_count_mse_map
Example #8
def train_complete(hyper_params,
                   Model,
                   train_reader,
                   val_reader,
                   user_count,
                   item_count,
                   model,
                   review=True):
    import time
    import datetime as dt

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    from loss import MSELoss
    from eval import evaluate, eval_ranking
    from utils import file_write, is_cuda_available, load_obj, log_end_epoch, init_transnet_optim

    file_write(hyper_params['log_file'],
               "\n\nSimulation run on: " + str(dt.datetime.now()) + "\n\n")
    file_write(hyper_params['log_file'], "Data reading complete!")
    file_write(hyper_params['log_file'],
               "Number of train batches: {:4d}".format(len(train_reader)))
    file_write(hyper_params['log_file'],
               "Number of validation batches: {:4d}".format(len(val_reader)))

    criterion = MSELoss(hyper_params)

    if hyper_params['model_type'] in ['transnet', 'transnet++']:
        optimizer = init_transnet_optim(hyper_params, model)

    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=hyper_params['lr'],
                                     weight_decay=hyper_params['weight_decay'])

    file_write(hyper_params['log_file'], str(model))
    file_write(hyper_params['log_file'],
               "\nModel Built!\nStarting Training...\n")

    try:
        best_MSE = float('inf')

        for epoch in range(1, hyper_params['epochs'] + 1):
            epoch_start_time = time.time()

            # Training for one epoch
            metrics = train(model, criterion, optimizer, train_reader,
                            hyper_params)
            metrics['dataset'] = hyper_params['dataset']
            # log_end_epoch(hyper_params, metrics, epoch, time.time() - epoch_start_time, metrics_on = '(TRAIN)')

            # Calculating the metrics on the validation set
            metrics, _, _ = evaluate(model,
                                     criterion,
                                     val_reader,
                                     hyper_params,
                                     user_count,
                                     item_count,
                                     review=review)
            metrics['dataset'] = hyper_params['dataset']
            log_end_epoch(hyper_params,
                          metrics,
                          epoch,
                          time.time() - epoch_start_time,
                          metrics_on='(VAL)')

            # Save best model on validation set
            if metrics['MSE'] < best_MSE:
                print("Saving model...")
                torch.save(model.state_dict(), hyper_params['model_path'])
                best_MSE = metrics['MSE']

    except KeyboardInterrupt:
        print('Exiting from training early')

    # Load best model and return it for evaluation on test-set
    model = Model(hyper_params)
    if is_cuda_available: model = model.cuda()
    model.load_state_dict(torch.load(hyper_params['model_path']))
    model.eval()

    return model
Example #9
dense_lr = 0.01
lr_decay_rate = 0.95
lr_decay_freq = 10

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(299),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
val_transform = transforms.Compose(
    [transforms.RandomResizedCrop(299),
     transforms.ToTensor()])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
corLoss = CORLoss()
mseLoss = MSELoss()
emdLoss = EMDLoss()
cepLoss = CEPLoss()
edwklLoss = EDWKLLoss()

model = ourNet()
model = model.to(device)

optimizer = optim.SGD([{'params': model.parameters(), 'lr': LR}], momentum=0.9)

trainset = AVADataset(csv_file=TRAIN_CSV_FILE,
                      root_dir=TRAIN_IMG_PATH,
                      transform=train_transform)
valset = AVADataset(csv_file=VAL_CSV_FILE,
                    root_dir=VAL_IMG_PATH,
                    transform=val_transform)
Example #10
def foo(mod, op, d):
    if (op[0] == "linear"):
        xx = Linear(d)

    # rnncell, lstmcell, grucell
    elif (mod[0] in ["LSTMCell", "GRUCell"]) and (op[0] == "forward"):
        xx = RNNCell(d)

    elif op[0] in [
            "conv1d",
            "conv2d",
    ]:
        xx = Conv(d)

    elif (op[0] in Pointwise.ops):
        xx = Pointwise(d)

    elif (op[0] in Convert.ops):
        xx = Convert(d)

    elif op[0] in ["__matmul__", "matmul"]:
        xx = Matmul(d)

    elif op[0] == "embedding":
        xx = Embedding(d)

    #reduction
    elif op[0] == "sum":
        xx = Sum(d)

    elif op[0] == "mean":
        xx = Mean(d)

    elif op[0] == "norm":
        xx = Norm(d)

    elif op[0] == "dropout":
        xx = Dropout(d)

    #Index, Slice, Join, Mutate
    elif (op[0] == "cat"):
        xx = Cat(d)

    elif (op[0] == "reshape"):
        xx = Reshape(d)

    elif (op[0] == "masked_scatter_"):
        xx = MaskedScatter(d)

    elif (op[0] == "gather"):
        xx = Gather(d)

    elif (op[0] == "nonzero"):
        xx = Nonzero(d)

    elif (op[0] == "index_select"):
        xx = IndexSelect(d)

    elif (op[0] == "masked_select"):
        xx = MaskedSelect(d)

    #blas
    elif op[0] in ["addmm", "addmm_"]:
        xx = Addmm(d)

    elif op[0] == "mm":
        xx = Mm(d)

    elif op[0] == "bmm":
        xx = Bmm(d)

    #softmax
    elif op[0] == "softmax":
        xx = Softmax(d)

    elif op[0] == "log_softmax":
        xx = LogSoftmax(d)

    #loss
    elif op[0] == "mse_loss":
        xx = MSELoss(d)

    #optimizers
    elif op[0] == "adam":
        xx = Adam(d)

    #normalization
    elif op[0] == "batch_norm":
        xx = BatchNorm(d)

    #random
    elif op[0] == "randperm":
        xx = RandPerm(d)

    #misc
    elif op[0] == "copy_":
        xx = Copy(d)

    elif op[0] == "clone":
        xx = Clone(d)

    elif op[0] == "contiguous":
        xx = Contiguous(d)

    elif op[0] == "any":
        xx = Any(d)

    elif (op[0] in Activation.ops):
        xx = Activation(d)

    elif op[0] == "to":
        xx = Convert(d)

    else:
        xx = Foo(d)

    return xx
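# Aside: the chain above is a plain name-to-handler dispatch; a table-driven
# sketch of the same idea (reusing the handler classes above, with Foo as the
# fallback) could look like this. Membership tests such as Pointwise.ops or
# Activation.ops would still need their own entries or a second lookup pass.
DISPATCH = {"linear": Linear, "sum": Sum, "mean": Mean, "norm": Norm,
            "dropout": Dropout, "mse_loss": MSELoss, "adam": Adam}

def foo_table(mod, op, d):
    return DISPATCH.get(op[0], Foo)(d)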
Example #11
def train_model(model, train_input, train_target, nb_epochs, mini_batch_size, criterion=MSELoss(), eta=1e-3):
    model.reset_params()
    for e in range(nb_epochs):
        sum_loss = 0
        for b in range(0, train_input.size(0), mini_batch_size):
            # forward pass
            output = model.forward(train_input.narrow(0, b, mini_batch_size))
            loss = criterion.forward(output, train_target.narrow(0, b, mini_batch_size))
            sum_loss += loss.item()

            # backward pass
            model.reset_gradient()
            model.backward(criterion.backward(output, train_target.narrow(0, b, mini_batch_size)))
            model.update_params(eta)
        print('epoch {:3d}: loss {:.6f}'.format(e, sum_loss))
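# A hypothetical invocation of train_model, assuming the project's own
# mini-framework (Sequential, Linear, Tanh, MSELoss, as imported in Example #13);
# the Linear(in, out) signature, layer sizes and hyper-parameters are guesses.
import torch

from module import Linear, Sequential
from activation import Tanh
from loss import MSELoss

train_input = torch.randn(1000, 2)
train_target = torch.randn(1000, 1)
model = Sequential([Linear(2, 25), Tanh(), Linear(25, 1)])
train_model(model, train_input, train_target,
            nb_epochs=25, mini_batch_size=50,
            criterion=MSELoss(), eta=1e-3)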
Example #12
def main(argv):
    TRAIN, NOISE_TYPES, IMAGE_SIZE, FRAME_SIZE, OVERLAY_SIZE, LATENT_CLEAN_SIZE, BATCH_SIZE, EPOCHS, TEST = arguments_parsing(argv)
    
    if TRAIN:
        print('model training with parameters:\n'+
              'noise types = {}\n'.format(NOISE_TYPES)+
              'image size = {}\n'.format(IMAGE_SIZE)+
              'frame size = {}\n'.format(FRAME_SIZE)+
              'overlay size = {}\n'.format(OVERLAY_SIZE)+
              'latent clean size = {}\n'.format(LATENT_CLEAN_SIZE)+
              'batch size = {}\n'.format(BATCH_SIZE)+
              'number of epochs = {}\n'.format(EPOCHS))
        
        # dataset table creating
        make_dataset_table(PATH_TO_DATA, NOISE_TYPES, PATH_TO_DATASET_TABLE)
        train_test_split(PATH_TO_DATASET_TABLE, test_size=0.2)

        # dataset and dataloader creating
        torch.manual_seed(0)
        transforms = [Compose([RandomHorizontalFlip(p=1.0), ToTensor()]),
                      Compose([RandomVerticalFlip(p=1.0), ToTensor()]),
                      Compose([ColorJitter(brightness=(0.9, 2.0), contrast=(0.9, 2.0)), ToTensor()])]

        train_datasets = []
        for transform in transforms:
            dataset = DenoisingDataset(dataset=pd.read_csv(PATH_TO_DATASET_TABLE),
                                       image_size=IMAGE_SIZE,
                                       frame_size=FRAME_SIZE,
                                       overlay_size=OVERLAY_SIZE,
                                       phase='train',
                                       transform=transform)
            train_datasets.append(dataset)
        train_dataset = ConcatDataset(train_datasets)

        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True, # can be set to True only for train loader
                                  num_workers=0)

        # model training
        model = AE(1, LATENT_CLEAN_SIZE)
        loss = SSIMLoss()
        latent_loss = MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-3)
        model = train_model(model, train_loader,
                            loss, latent_loss,
                            optimizer,
                            epochs=EPOCHS,
                            device=DEVICE)

        # model saving
        path_to_model = './model' + '_{}'.format('_'.join([str(elem) for elem in NOISE_TYPES])) + '.pth'
        torch.save(model, path_to_model)
    
    if TEST:    
        # model loading
        path_to_model = './model' + '_{}'.format('_'.join([str(elem) for elem in NOISE_TYPES])) + '.pth'
        print('{} testing...\n'.format(os.path.basename(path_to_model)))
        model = torch.load(path_to_model)

        dataset = pd.read_csv(PATH_TO_DATASET_TABLE)
        test_dataset = dataset[dataset['phase'] == 'test']

        # model testing and results saving
        loss = SSIMLoss()
        latent_loss = MSELoss()
        print('{} evaluation on test images'.format(os.path.basename(path_to_model)))
        test_evaluation(model, test_dataset,
                        loss, latent_loss,
                        device=DEVICE)
        print()
        
        path_to_results = PATH_TO_RESULTS + '_{}'.format('_'.join([str(elem) for elem in NOISE_TYPES]))
        if not os.path.exists(path_to_results):
            os.makedirs(path_to_results)
        print('{} running and results saving'.format(os.path.basename(path_to_model)))
        test_model(model, test_dataset, path_to_results)
        
    print('process completed: OK')
Example #13
import torch
from tqdm import tqdm

from module import Linear, Sequential
from activation import Tanh, ReLU, LeakyReLU, PReLU
from optimizer import SGD
from loss import MSELoss, CrossEntropyLoss
from utils import gen_disc_set, plot_dataset, build_CV_sets, standardise_input, train, test

if __name__ == '__main__':

    lr = 0.01
    k_fold = 10
    CV_sets = build_CV_sets(k_fold, 1000)
    print('CV sets built.')
    test_input, test_target = gen_disc_set(1000)

    for criterion in [MSELoss(), CrossEntropyLoss()]:
        for mini_batch_size in [20]:
            for activation in [Tanh()]:

                print('***')
                print('Criterion: {}, mini_batch_size: {}, activation: {}.'.
                      format(criterion.name(), mini_batch_size,
                             activation.name()))
                print('***')

                training_time_acc = []
                test_error_acc = []
                for i in tqdm(range(k_fold), leave=False):

                    torch.manual_seed(2019)
                    model = Sequential([
Example #14
def train(**kwargs):
    """
    训练模型
    """
    opt._parse(kwargs), opt._print_conf()
    vis = Visualizer(opt.env)

    # step1: data
    train_dataset = CornellDataset(opt.train_data_root, train=True)
    eval_dataset = CornellDataset(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_dataset, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
    eval_dataloader = DataLoader(eval_dataset, opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    # step2: configure model
    model = getattr(models, opt.model_name)() # getattr looks the model class up by name; modules inside a package are reachable as its attributes
    if opt.load_model_path:
        checkpoints = torch.load(opt.load_model_path) # checkpoints are saved as a dict, so extra parameters can be stored alongside the weights
        opt.start_epoch = checkpoints['start_epoch']
        model.load_state_dict(checkpoints['state_dict'])
    if opt.multi_gpu: model = torch.nn.parallel.DataParallel(model).to(opt.device) # single machine, multiple GPUs
    else: model.to(opt.device) # single machine, single GPU
        
    # step3: criterion and optimizer
    criterion = MSELoss() # the loss tensor lives on the same device as its inputs; no explicit .cuda() needed
    lr = opt.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
    
    # step4: meters
    loss_meter = meter.AverageValueMeter()
    # only checkpoint models that actually improve, to save storage
    previous_loss = 1e10
    best_acc = 0.7

    model.train()
    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(opt.start_epoch, opt.num_epoch):
            print('*' * 40), print(f'epoch {epoch}')
            loss_meter.reset()
            for step, (inputs, labels, imgIds) in tqdm(enumerate(train_dataloader)):
                inputs, labels = inputs.to(opt.device), labels.to(opt.device)
            
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_meter.add(loss.item())

                if (step + 1)%opt.print_freq == 0:
                    vis.plot('loss', loss_meter.value()[0])

            # validate / evaluate
            eval_acc = eval(model, eval_dataloader)
            print(f'[eval_acc:{eval_acc:.4f}]'), vis.plot('eval_acc', eval_acc)
            vis.log(f"[epoch:{epoch}  eval_acc:{eval_acc:.4f} lr:{lr}]")

            if eval_acc >= best_acc:
                # save an intermediate checkpoint: network parameters only
                print(f'Saving epoch {epoch} model ...')
                state_dict = model.state_dict()
                if opt.multi_gpu: # a DataParallel state_dict prefixes every key with 'module.' (7 chars); strip it
                    for k in list(state_dict.keys()): # snapshot the keys, since the dict is mutated inside the loop
                        state_dict[k[7:]] = state_dict.pop(k)
                checkpoints = {'start_epoch':epoch+1, 'state_dict':state_dict} # epoch 0 is finished, so the next run starts at epoch 1
                torch.save(checkpoints, f'checkpoints/{opt.model_name}_{opt.dataset_name}_lr{opt.lr}_ld{opt.lr_decay}_bs{opt.batch_size}_wd{opt.weight_decay}_epoch{epoch}_acc{eval_acc:.3f}.ckpt')

            # update learning rate
            if loss_meter.value()[0] > previous_loss:
                lr = lr * opt.lr_decay
                # decaying the LR via param_groups keeps optimizer state (e.g. momentum) intact
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            
            previous_loss = loss_meter.value()[0]

    print('Finished Training') # note: checkpoints use a custom dict format; downstream code must read the dict rather than load_state_dict the file directly
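# Self-contained illustration of the 'module.' prefix handling above: wrapping a
# model in DataParallel prefixes every state_dict key, and the prefix has to be
# stripped before the weights can be loaded back into a bare model.
import torch

wrapped = torch.nn.DataParallel(torch.nn.Linear(4, 2))
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

stripped = {k[len('module.'):]: v for k, v in wrapped.state_dict().items()}
torch.nn.Linear(4, 2).load_state_dict(stripped)  # loads cleanly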