Example #1
    def train(self, train, valid, dim_size, r, lr, epoch_size, desc_step,
              early_stopping):
        # param config
        self.train_size = len(train["asin"])
        self.valid_size = len(valid["asin"])

        # data preparation
        # user and item latent factor matrices
        self.U = np.random.rand(self.user_size, dim_size)
        self.I = np.random.rand(self.item_size, dim_size)

        # global mean rating, user biases, item biases
        self.m = np.sum(train["overall"]) / self.train_size
        self.bu = np.random.rand(self.user_size)
        self.bi = np.random.rand(self.item_size)

        # for ALS, reshape train into {userid: ([itemid, itemid, ...], [overall, overall, ...]), ...}

        ALS_train_user = {}
        for userid, itemid, overall in zip(train["reviewerID"], train["asin"],
                                           train["overall"]):
            if userid not in ALS_train_user:
                ALS_train_user[userid] = ([], [])
            ALS_train_user[userid][0].append(itemid)
            ALS_train_user[userid][1].append(overall)

        ALS_train_item = {}
        for userid, itemid, overall in zip(train["reviewerID"], train["asin"],
                                           train["overall"]):
            if itemid not in ALS_train_item:
                ALS_train_item[itemid] = ([], [])
            ALS_train_item[itemid][0].append(userid)
            ALS_train_item[itemid][1].append(overall)

        del train

        for epoch in range(epoch_size):
            train_losses = []
            for i, items in ALS_train_user.items():  # fix user factors, update only item factors
                idx = items[0]
                overalls = np.array(items[1])
                preds = np.sum(self.U[i] * self.I[idx],
                               axis=1) + self.m + self.bu[i] + self.bi[idx]
                errors = (overalls - preds)
                train_losses += errors.tolist()
                self.I[idx] = self.I[idx] + lr * (np.matmul(
                    errors.reshape(-1, 1), self.U[i].reshape(1, -1)) -
                                                  r * self.I[idx])
                self.bi[idx] = self.bi[idx] + lr * (errors - r * self.bi[idx])

            for j, users in ALS_train_item.items():  # fix item factors, update only user factors
                idx = users[0]
                overalls = np.array(users[1])
                preds = np.sum(self.I[j] * self.U[idx],
                               axis=1) + self.m + self.bu[idx] + self.bi[j]
                errors = (overalls - preds)
                train_losses += errors.tolist()
                self.U[idx] = self.U[idx] + lr * (np.matmul(
                    errors.reshape(-1, 1), self.I[j].reshape(1, -1)) -
                                                  r * self.U[idx])
                self.bu[idx] = self.bu[idx] + lr * (errors - r * self.bu[idx])

            self.train_losses.append(mse(train_losses))

            # validation
            valid_losses = []
            for i, j, overall in zip(valid["reviewerID"], valid["asin"],
                                     valid["overall"]):
                pred = np.dot(self.U[i],
                              self.I[j]) + self.m + self.bu[i] + self.bi[j]
                valid_losses.append(pred - overall)

            self.valid_losses.append(mse(valid_losses))
            # NaN self-inequality check: NaN != NaN, so this fires if the loss diverged
            if self.valid_losses[-1] != self.valid_losses[-1]:
                print("epoch ", epoch + 1,
                      " Invalid: overflow or division by zero occurred")
                return 0

            # describe
            if (epoch + 1) % desc_step == 0:
                self.desc(epoch)

            # early stopping
            if (early_stopping and epoch != 0
                    and self.valid_losses[-2] < self.valid_losses[-1]):
                self.desc(epoch)
                return 0
        self.desc(epoch_size - 1)
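
The examples above call an mse helper that is not shown. A minimal sketch, assuming it takes a flat list of per-sample errors (pred - overall) and returns their mean square:

import numpy as np

def mse(errors):
    # mean of squared per-sample errors; if training diverged the result
    # is NaN/inf, which the NaN self-inequality check in train() catches
    errors = np.asarray(errors, dtype=float)
    return float(np.mean(errors ** 2))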
Example #2
    def evaluate(self):
        loss = mse(self.test_losses)
        print("     test loss is : ", loss)
        return loss
Example #3
def main():
    # Metric path
    metric_path = os.getcwd() + '/utils/metric'
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    opt = option.dict_to_nonedict(
        opt)  # convert to NoneDict, which returns None for missing keys

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # configure loggers; before this is done, logging does not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    # connect to a running shared MATLAB session to compute the NIQE metric
    eng = matlab.engine.connect_matlab()
    names = matlab.engine.find_matlab()
    print('matlab process name: ', names)
    eng.addpath(metric_path)

    for epoch in range(start_epoch, total_epochs):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0:
                avg_psnr = 0.0
                scores = 0.0
                avg_pirm_rmse = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.calculate_psnr(cropped_sr_img_y * 255,
                                                    cropped_gt_img_y * 255)
                    immse = util.mse(cropped_sr_img_y * 255,
                                     cropped_gt_img_y * 255)
                    avg_pirm_rmse += immse
                    scores += eng.calc_NIQE(save_img_path, 4)

                avg_psnr = avg_psnr / idx
                scores = scores / idx
                avg_pirm_rmse = math.sqrt(avg_pirm_rmse / idx)

                # log
                logger.info(
                    '# Validation # PSNR: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(avg_psnr, scores, avg_pirm_rmse))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, NIQE: {:.4e}, pirm_rmse: {:.4e}'
                    .format(epoch, current_step, avg_psnr, scores,
                            avg_pirm_rmse))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('NIQE', scores, current_step)
                    tb_logger.add_scalar('pirm_rmse', avg_pirm_rmse,
                                         current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
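
For orientation, a minimal sketch of the option dict this script reads, reconstructed only from the keys accessed above; every concrete value here is a hypothetical placeholder, and real option files for this kind of codebase carry more fields (network, model, and dataset paths among them):

opt_sketch = {
    'name': 'sr_experiment',        # hypothetical; 'debug' in the name disables tb_logger
    'use_tb_logger': True,
    'scale': 4,                     # also used as the PSNR crop border
    'path': {
        'resume_state': None,       # path to a saved training state, or None
        'experiments_root': '...',  # placeholder
        'log': '...',
        'val_images': '...',
    },
    'train': {
        'manual_seed': None,        # None -> a random seed is drawn
        'niter': 500000,            # hypothetical total iteration budget
        'val_freq': 5000,           # hypothetical
    },
    'logger': {
        'print_freq': 100,          # hypothetical
        'save_checkpoint_freq': 5000,
    },
    'datasets': {
        'train': {'name': 'train_set', 'batch_size': 16},  # hypothetical
        'val': {'name': 'val_set'},
    },
}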
Example #4
    def evaluate(self):
        loss = mse(self.errors)
        print("test loss is : ", loss)
        return loss
    def train(self, train, valid, dim_size, r, lr, epoch_size, desc_step,
              early_stopping):
        # LDA Part
        words_list = list(self.words_by_item.values())
        words_dictionary = corpora.Dictionary(words_list)
        corpus = [words_dictionary.doc2bow(words) for words in words_list]
        lda_model = LdaModel(corpus=corpus,
                             num_topics=dim_size,
                             id2word=words_dictionary,
                             random_state=1)
        self.I = np.zeros((self.item_size, dim_size))
        for item_id, text in self.words_by_item.items():
            bow = words_dictionary.doc2bow(text)
            topics = lda_model[bow]
            temp = np.zeros(dim_size)
            for topic_id, prob in topics:
                temp[topic_id] = prob
            self.I[item_id] = temp
        ######################################################################

        # param config
        self.train_size = len(train["asin"])
        self.valid_size = len(valid["asin"])

        # data preparation
        # user latent factors; self.I keeps the LDA topic features built
        # above, so it is not re-randomized here
        self.U = np.random.rand(self.user_size, dim_size)

        # global mean rating, user biases, item biases
        self.m = np.sum(train["overall"]) / self.train_size
        self.bu = np.random.rand(self.user_size)
        self.bi = np.random.rand(self.item_size)

        # for ALS, reshape train into {userid: ([itemid, itemid, ...], [overall, overall, ...]), ...}
        ALS_train_user = {}
        for userid, itemid, overall in zip(train["reviewerID"], train["asin"],
                                           train["overall"]):
            if userid not in ALS_train_user:
                ALS_train_user[userid] = ([], [])
            ALS_train_user[userid][0].append(itemid)
            ALS_train_user[userid][1].append(overall)

        ALS_train_item = {}
        for userid, itemid, overall in zip(train["reviewerID"], train["asin"],
                                           train["overall"]):
            if itemid not in ALS_train_item:
                ALS_train_item[itemid] = ([], [])
            ALS_train_item[itemid][0].append(userid)
            ALS_train_item[itemid][1].append(overall)
        del train

        for epoch in range(epoch_size):
            train_losses = []
            for i, items in ALS_train_user.items():  # fix user factors, update only item factors
                idx = items[0]
                overalls = np.array(items[1])
                preds = np.sum(self.U[i] * self.I[idx],
                               axis=1) + self.m + self.bu[i] + self.bi[idx]
                errors = (overalls - preds)
                train_losses += errors.tolist()
                # item factors stay fixed to the LDA topic features, so the
                # factor update is disabled; only the item bias is learned
                # self.I[idx] = self.I[idx] + lr * (np.matmul(errors.reshape(-1, 1), self.U[i].reshape(1, -1)) - r * self.I[idx])
                self.bi[idx] = self.bi[idx] + lr * (errors - r * self.bi[idx])

            for j, users in ALS_train_item.items():  # fix item factors, update only user factors
                idx = users[0]
                overalls = np.array(users[1])
                preds = np.sum(self.I[j] * self.U[idx],
                               axis=1) + self.m + self.bu[idx] + self.bi[j]
                errors = (overalls - preds)
                train_losses += errors.tolist()
                self.U[idx] = self.U[idx] + lr * (np.matmul(
                    errors.reshape(-1, 1), self.I[j].reshape(1, -1)) -
                                                  r * self.U[idx])
                self.bu[idx] = self.bu[idx] + lr * (errors - r * self.bu[idx])

            self.train_losses.append(mse(train_losses))

            # validation
            valid_losses = []
            for i, j, overall in zip(valid["reviewerID"], valid["asin"],
                                     valid["overall"]):
                pred = np.dot(self.U[i],
                              self.I[j]) + self.m + self.bu[i] + self.bi[j]
                valid_losses.append(pred - overall)

            self.valid_losses.append(mse(valid_losses))
            # NaN self-inequality check: NaN != NaN, so this fires if the loss diverged
            if self.valid_losses[-1] != self.valid_losses[-1]:
                print("epoch ", epoch + 1,
                      " Invalid: overflow or division by zero occurred")
                return 0

            # describe
            if (epoch + 1) % desc_step == 0:
                self.desc(epoch)

            # early stopping
            if (early_stopping and epoch != 0
                    and self.valid_losses[-2] < self.valid_losses[-1]):
                self.desc(epoch)
                return 0
        self.desc(epoch_size - 1)
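
A hedged usage sketch for the trainers above. The wrapper class and its constructor are hypothetical, since the snippets only show methods; it also assumes reviewerID and asin have already been remapped to contiguous integer ids, because they are used directly as row indices into self.U, self.I, self.bu, and self.bi:

# hypothetical wrapper class exposing the train()/evaluate() methods above
model = MatrixFactorization(user_size=2, item_size=2)

train_data = {"reviewerID": [0, 0, 1], "asin": [0, 1, 1],
              "overall": [5.0, 3.0, 4.0]}
valid_data = {"reviewerID": [1], "asin": [0], "overall": [4.0]}

model.train(train_data, valid_data,
            dim_size=16,         # number of latent factors (and LDA topics in Example #4)
            r=0.01,              # L2 regularization strength (hypothetical value)
            lr=0.005,            # learning rate (hypothetical value)
            epoch_size=100,
            desc_step=10,        # report progress every 10 epochs
            early_stopping=True)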