Пример #1
0
def load_model(load_path):
    """Rebuild an HourglassNetwork and its Adam optimizer from a checkpoint file.

    The checkpoint is expected to hold ``'args_dict'`` (hyper-parameters used to
    construct the network, the device name and the learning rate),
    ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Args:
        load_path: path passed to ``torch.load``.

    Returns:
        Tuple ``(device, model, optimizer, args_dict)``. The model is moved to
        the checkpoint's device and cast to float64 before its weights are loaded.
    """
    print('Loading model and optimizer states with metadata...')
    checkpoint = torch.load(load_path)
    hparams = checkpoint['args_dict']

    network = HourglassNetwork(
        num_channels=hparams['channels'],
        num_stacks=hparams['stacks'],
        num_classes=hparams['joints'],
        input_shape=(hparams['input_dim'], hparams['input_dim'], 3),
    )
    target_device = torch.device(hparams['device'])
    # The optimizer is created after the move/cast so it references the final
    # parameter tensors.
    network = network.to(target_device).double()
    adam = Adam(network.parameters(), hparams['lr'])

    network.load_state_dict(checkpoint['model_state_dict'])
    adam.load_state_dict(checkpoint['optimizer_state_dict'])

    print('DONE')

    return target_device, network, adam, hparams
Пример #2
0
    def load(cls, path: Path, device='cpu'):
        """Restore a ModelPackage from the archive at ``path``.

        The archive is unpacked into a sibling directory ``<path>_load_tmp``,
        which is always removed afterwards, including when loading fails.

        Args:
            path: archive readable by ``shutil.unpack_archive``.
            device: device name or ``torch.device``. The networks are moved
                there, and every saved tensor is mapped onto it.

        Returns:
            ModelPackage holding the Tacotron and GAN models (weights loaded with
            ``strict=False``), three Adam optimizers with their saved state
            restored, and the loaded Config.
        """
        device = torch.device(device)
        tmp_dir = Path(str(path) + '_load_tmp')
        try:
            shutil.unpack_archive(str(path), extract_dir=tmp_dir)
            cfg = Config.load(tmp_dir / 'config.yaml')

            def load_state(file_name):
                # Map every stored tensor directly onto the requested device.
                return torch.load(tmp_dir / file_name, map_location=device)

            tacotron = Tacotron.from_config(cfg).to(device)
            tacotron.load_state_dict(load_state('tacotron.pyt'), strict=False)

            gan = GAN.from_config(cfg).to(device)
            gan.load_state_dict(load_state('gan.pyt'), strict=False)

            # Optimizers must be built over the final (moved) parameters before
            # their saved state is loaded.
            taco_opti = Adam(tacotron.parameters())
            taco_opti.load_state_dict(load_state('taco_opti.pyt'))

            gen_opti = Adam(gan.generator.parameters())
            gen_opti.load_state_dict(load_state('gen_opti.pyt'))

            disc_opti = Adam(gan.discriminator.parameters())
            disc_opti.load_state_dict(load_state('disc_opti.pyt'))

            model_package = ModelPackage(tacotron=tacotron,
                                         gan=gan,
                                         taco_opti=taco_opti,
                                         gen_opti=gen_opti,
                                         disc_opti=disc_opti,
                                         cfg=cfg)
        finally:
            # Best-effort cleanup: the directory may be missing or only partly
            # written if unpacking failed.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        return model_package
Пример #3
0
def main(args):
    """Train / cross-validate a pose model for data source = YogiDB.

    Loads the image set and annotations from YogiDB and builds ``Generic``
    train and validation datasets. It then creates an hourglass model
    (``args.arch`` in ``hg1``/``hg2``/``hg8``) and an Adam or RMSprop optimizer,
    optionally resumes from ``args.resume``, and runs the epoch loop. Each
    epoch is logged, and a checkpoint is written under ``args.checkpoint``.

    Raises:
        ValueError: if ``args.arch`` is not a known architecture.
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    import copy  # used to give the validation loader its own dataset object

    # Create data loader. Reference signature of the dataset class:
    # Generic(data.Dataset)(image_set, annotations,
    #                       is_train=True, inp_res=256, out_res=64, sigma=1,
    #                       scale_factor=0, rot_factor=0, label_type='Gaussian',
    #                       rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)
    annotations_source = 'basic-thresholder'

    # Get the data from yogi
    db_obj = YogiDB(config.db_url)
    imageset = db_obj.get_filtered(ImageSet,
                                   name=args.image_set_name)
    annotations = db_obj.get_annotations(image_set_name=args.image_set_name,
                                         annotation_source=annotations_source)
    # The number of joints in the first annotation sets the number of outputs.
    num_classes = torch.Tensor(annotations[0]['joint_self']).size(0)

    crop = bool(args.crop)
    crop_size = args.crop if crop else 512

    # Using the default RGB mean and std dev as 0
    RGB_MEAN = torch.as_tensor([0.0, 0.0, 0.0])
    RGB_STDDEV = torch.as_tensor([0.0, 0.0, 0.0])

    dataset = Generic(image_set=imageset,
                      inp_res=args.inp_res,
                      out_res=args.out_res,
                      annotations=annotations,
                      mode=args.mode,
                      crop=crop, crop_size=crop_size,
                      rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)

    train_dataset = dataset
    train_dataset.is_train = True
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch, shuffle=True,
                              num_workers=args.workers, pin_memory=True)

    # A shallow copy gives the validation loader its own ``is_train`` flag.
    # Aliasing the same object would also switch the training loader into
    # evaluation mode.
    # NOTE(review): train and validation still draw from the same samples;
    # confirm this is intended.
    val_dataset = copy.copy(dataset)
    val_dataset.is_train = False
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    # Select the hardware device to use for inference.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    # NOTE(review): assumes do_training_epoch re-enables gradients locally.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False, num_classes=num_classes)
    else:
        raise ValueError('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    if args.optimizer == "Adam":
        # Adam takes no ``momentum`` argument (passing one raises TypeError);
        # args.momentum only applies to RMSprop.
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    else:
        optimizer = RMSprop(model.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    best_acc = 0

    # optionally resume from a checkpoint
    title = args.data_identifier + ' ' + args.arch
    log_path = os.path.join(args.checkpoint, 'log.txt')
    if args.resume:
        if not os.path.isfile(args.resume):
            raise FileNotFoundError(
                "no checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        # Map tensors onto the selected device so a GPU checkpoint can resume
        # on a CPU-only machine (and vice versa).
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(log_path, title=title, resume=True)
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader, model, device, optimizer)

        # evaluate on validation set
        if args.debug == 1:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                val_loader, model, device, False, True,
                os.path.join(args.checkpoint, 'debug.csv'), epoch + 1)
        else:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(
                val_loader, model, device, False)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
Пример #4
0
class ChatSpaceTrainer:
    """Drive training, evaluation and saving of a ChatSpaceModel.

    The model is optimized with Adam (``config["learning_rate"]``) against an
    ``nn.NLLLoss`` computed over per-position outputs. Training data comes from
    a repeating DynamicCorpus; evaluation data is optional.
    """

    def __init__(
        self,
        config,
        model: "ChatSpaceModel",
        vocab: "Vocab",
        device: torch.device,
        train_corpus_path,
        eval_corpus_path=None,
        encoding="utf-8",
    ):
        """Set up the optimizer, loss, and training (and optional eval) datasets.

        Args:
            config: mapping that provides at least ``"learning_rate"``, and
                ``"logging_step"`` when ``train`` is used.
            model: the network to train.
            vocab: vocabulary passed to every ChatSpaceDataset.
            device: device each batch is moved to.
            train_corpus_path: path of the training corpus.
            eval_corpus_path: optional path of the evaluation corpus.
            encoding: text encoding used when reading the corpora.
        """
        self.config = config
        self.device = device
        self.model = model
        self.optimizer = Adam(self.model.parameters(),
                              lr=config["learning_rate"])
        self.criterion = nn.NLLLoss()
        self.vocab = vocab
        self.encoding = encoding

        self.train_corpus = DynamicCorpus(train_corpus_path,
                                          repeat=True,
                                          encoding=self.encoding)
        self.train_dataset = ChatSpaceDataset(config,
                                              self.train_corpus,
                                              self.vocab,
                                              with_random_space=True)

        # Evaluation data is optional; eval() refuses to run without it.
        self.eval_corpus = None
        self.eval_dataset = None
        if eval_corpus_path is not None:
            self.eval_corpus = DynamicCorpus(eval_corpus_path,
                                             encoding=self.encoding)
            # Build the dataset from the corpus object, as for the training
            # data, not from the raw path string.
            self.eval_dataset = ChatSpaceDataset(self.config,
                                                 self.eval_corpus,
                                                 self.vocab,
                                                 with_random_space=True)

        self.global_epochs = 0
        self.global_steps = 0

    def eval(self, batch_size=64):
        """Evaluate on the eval dataset without gradients and return epoch metrics.

        The model is put in eval mode for the pass and back in train mode
        afterwards, even if evaluation raises.

        Raises:
            ValueError: if the trainer was created without ``eval_corpus_path``.
        """
        if self.eval_dataset is None:
            raise ValueError("no eval_corpus_path was given; nothing to evaluate")

        self.model.eval()
        try:
            with torch.no_grad():
                eval_output = self.run_epoch(self.eval_dataset,
                                             batch_size=batch_size,
                                             is_train=False)
        finally:
            self.model.train()
        return eval_output

    def train(self, epochs=10, batch_size=64):
        """Run ``epochs`` training epochs, writing a checkpoint and weights after each."""
        for epoch_id in range(epochs):
            self.run_epoch(
                self.train_dataset,
                batch_size=batch_size,
                epoch_id=epoch_id,
                is_train=True,
                log_freq=self.config["logging_step"],
            )
            self.save_checkpoint(
                f"outputs/checkpoints/checkpoint_ep{epoch_id}.cpt")
            self.save_model(f"outputs/models/chatspace_ep{epoch_id}.pt")
            # NOTE(review): this targets outputs/jit_models/ but passes
            # as_jit=False, so a plain state_dict is written there. Confirm
            # whether tracing was disabled on purpose.
            self.save_model(f"outputs/jit_models/chatspace_ep{epoch_id}.pt",
                            as_jit=False)

    def run_epoch(self,
                  dataset,
                  batch_size=64,
                  epoch_id=0,
                  is_train=True,
                  log_freq=100):
        """Iterate once over ``dataset``.

        In training mode, every batch updates the model and metrics are printed
        every ``log_freq`` steps; returns None. In evaluation mode, every batch's
        outputs and metrics are collected and the aggregated epoch metric is
        returned. Only training steps advance ``global_steps``.
        """
        step_outputs, step_metrics, step_inputs = [], [], []
        collect_fn = (ChatSpaceDataset.train_collect_fn
                      if is_train else ChatSpaceDataset.eval_collect_fn)
        data_loader = DataLoader(dataset, batch_size, collate_fn=collect_fn)
        for step_num, batch in enumerate(data_loader):
            batch = {
                key: value.to(self.device)
                for key, value in batch.items()
            }
            # Forward is_train so evaluation batches do not bump global_steps.
            output = self.step(step_num, batch, is_train=is_train)

            if is_train:
                self.update(output["loss"])

            if not is_train or step_num % log_freq == 0:
                batch = {
                    key: value.cpu().numpy()
                    for key, value in batch.items()
                }
                output = {
                    key: value.detach().cpu().numpy()
                    for key, value in output.items()
                }

                metric = self.step_metric(output["output"], batch,
                                          output["loss"])

                if is_train:
                    print(
                        f"EPOCH:{epoch_id}",
                        f"STEP:{step_num}/{len(data_loader)}",
                        [(key + ":" + "%.3f" % metric[key]) for key in metric],
                    )
                else:
                    step_outputs.append(output)
                    step_metrics.append(metric)
                    step_inputs.append(batch)

        if is_train:
            self.global_epochs += 1
            return None
        return self.epoch_metric(step_inputs, step_outputs, step_metrics)

    def epoch_metric(self, step_inputs, step_outputs, step_metrics):
        """Aggregate collected eval batches into one metric dict.

        The result comes from ``calculated_metric``, with ``"loss"`` set to the
        mean of the per-step losses.
        """
        average_loss = np.mean([metric["loss"] for metric in step_metrics])

        epoch_inputs = [
            example for step_input in step_inputs
            for example in step_input["input"].tolist()
        ]
        epoch_outputs = [
            example for output in step_outputs
            for example in output["output"].argmax(axis=-1).tolist()
        ]
        epoch_labels = [
            example for step_input in step_inputs
            for example in step_input["label"].tolist()
        ]

        epoch_metric = calculated_metric(batch_input=epoch_inputs,
                                         batch_output=epoch_outputs,
                                         batch_label=epoch_labels)

        epoch_metric["loss"] = average_loss
        return epoch_metric

    def step_metric(self, output, batch, loss=None):
        """Compute metrics for one batch (predictions are argmax over the last axis)."""
        metric = calculated_metric(
            batch_input=batch["input"].tolist(),
            batch_output=output.argmax(axis=-1).tolist(),
            batch_label=batch["label"].tolist(),
        )

        if loss is not None:
            metric["loss"] = loss
        return metric

    def step(self, step_num, batch, with_loss=True, is_train=True):
        """Run the model forward on one batch.

        Returns ``{"output": ...}``, plus ``"loss"`` when ``with_loss`` is true.
        ``global_steps`` is incremented only when ``is_train`` is true.
        """
        output = self.model.forward(batch["input"], batch["length"])
        if is_train:
            self.global_steps += 1

        if not with_loss:
            return {"output": output}

        # The criterion expects the class dimension second, hence the transpose.
        loss = self.criterion(output.transpose(1, 2), batch["label"])
        return {"loss": loss, "output": output}

    def update(self, loss):
        """Back-propagate ``loss`` and apply one optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save_model(self, path, as_jit=False):
        """Save the model to ``path``.

        Args:
            path: destination file.
            as_jit: if true, trace the model on CPU with two sample sentences and
                save it with ``torch.jit.save``. Otherwise save ``state_dict()``
                with ``torch.save``.

        Gradients are cleared and every parameter is frozen
        (``requires_grad=False``) while saving. The original flags, and for JIT
        saves the model's device and train mode, are restored afterwards, even
        if saving raises.
        """
        self.optimizer.zero_grad()
        # (parameter, original requires_grad flag) pairs for restoring later.
        saved_flags = [(param, param.requires_grad)
                       for param in self.model.parameters()]
        for param, _ in saved_flags:
            # The autograd attribute is ``requires_grad``. Assigning to
            # ``require_grad`` would only create an unused attribute.
            param.requires_grad = False

        try:
            with torch.no_grad():
                if not as_jit:
                    torch.save(self.model.state_dict(), path)
                else:
                    self.model.cpu().eval()
                    try:
                        sample_texts = ["오늘 너무 재밌지 않았어?", "너랑 하루종일 놀아서 기분이 좋았어!"]
                        dataset = ChatSpaceDataset(self.config,
                                                   sample_texts,
                                                   self.vocab,
                                                   with_random_space=False)
                        data_loader = DataLoader(dataset,
                                                 batch_size=2,
                                                 collate_fn=dataset.eval_collect_fn)

                        # A single sample batch is enough to trace the graph.
                        for batch in data_loader:
                            model_input = (batch["input"].detach(),
                                           batch["length"].detach())
                            traced_model = torch.jit.trace(self.model, model_input)
                            torch.jit.save(traced_model, path)
                            break
                    finally:
                        self.model.to(self.device).train()
        finally:
            for param, flag in saved_flags:
                param.requires_grad = flag

        print(f"Model Saved on {path}{' as_jit' if as_jit else ''}")

    def save_checkpoint(self, path):
        """Write epoch/step counters plus model and optimizer state to ``path``."""
        torch.save(
            {
                "epoch": self.global_epochs,
                "steps": self.global_steps,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            path,
        )

    def load_checkpoint(self, path):
        """Restore a checkpoint written by ``save_checkpoint``.

        Tensors are mapped onto ``self.device``, so a checkpoint saved on another
        device (e.g. GPU) can be restored here.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.global_epochs = checkpoint["epoch"]
        self.global_steps = checkpoint["steps"]

    def load_model(self, model_path):
        """Load a plain ``state_dict`` file (as written by ``save_model``) onto ``self.device``."""
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
Пример #5
0
class Trainer:
    """Adversarial trainer for an ESRGAN generator and its discriminator.

    Hyper-parameters come from ``args``: the ``p_*`` variants are used when
    ``args.is_perceptual_oriented`` is set, the ``g_*`` variants otherwise.
    Models, loss criteria and batches are placed on CUDA.
    """

    def __init__(self, args, data_loader):
        """Build models, optimizers and schedulers, then optionally load or resume a checkpoint."""
        self.args = args
        self.data_loader = data_loader
        self.metric = PSNR()

        if args.is_perceptual_oriented:
            self.lr = args.p_lr
            self.content_loss_factor = args.p_content_loss_factor
            self.perceptual_loss_factor = args.p_perceptual_loss_factor
            self.adversarial_loss_factor = args.p_adversarial_loss_factor
            self.decay_iter = args.p_decay_iter
        else:
            self.lr = args.g_lr
            self.content_loss_factor = args.g_content_loss_factor
            self.perceptual_loss_factor = args.g_perceptual_loss_factor
            self.adversarial_loss_factor = args.g_adversarial_loss_factor
            self.decay_iter = args.g_decay_iter

        self.build_model(args)
        self.build_optimizer(args)
        if args.fp16:
            self.initialize_model_opt_fp16()
        if args.distributed:
            self.parallelize_model()
        self.history = {
            n: []
            for n in [
                'adversarial_loss', 'discriminator_loss', 'perceptual_loss',
                'content_loss', 'generator_loss', 'score'
            ]
        }
        if args.load:
            self.load_model(args)
        if args.resume:
            self.resume(args)
        self.build_scheduler(args)
        print(':D')

    def train(self, args):
        """Run epochs ``args.epoch`` .. ``args.num_epoch - 1`` of relativistic-average GAN training.

        Each batch updates the generator, then the discriminator, and advances
        both LR schedulers by ``self.n_unit_scheduler_step`` unit steps. Every
        5000 steps an HR/SR comparison image is saved. At the end of each epoch
        the last batch's losses and PSNR are appended to ``self.history``,
        ``last.pth`` is written, and ``best.pth`` is written when that PSNR
        beats the best seen so far.
        """
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().cuda()
        content_criterion = nn.L1Loss().cuda()
        perception_criterion = PerceptualLoss().cuda()
        self.best_score = -9999.
        self.generator.train()
        self.discriminator.train()

        print(f"{'epoch':>7s}"
              f"{'batch':>10s}"
              f"{'discr.':>10s}"
              f"{'gener.':>10s}"
              f"{'adver.':>10s}"
              f"{'percp.':>10s}"
              f"{'contn.':>10s}"
              f"{'PSNR':>10s}"
              f"")

        for epoch in range(args.epoch, args.num_epoch):
            sample_dir_epoch = Path(
                args.checkpoint_dir) / 'sample_dir' / str(epoch)
            sample_dir_epoch.mkdir(exist_ok=True, parents=True)

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].cuda()
                high_resolution = image['hr'].cuda()

                real_labels = torch.ones((high_resolution.size(0), 1)).cuda()
                fake_labels = torch.zeros((high_resolution.size(0), 1)).cuda()

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                # The generator uses swapped targets: real-vs-fake should look
                # "fake", and fake-vs-real should look "real".
                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr +
                                    adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution,
                                                       fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution,
                                                 high_resolution)

                generator_loss = (
                    adversarial_loss * self.adversarial_loss_factor +
                    perceptual_loss * self.perceptual_loss_factor +
                    content_loss * self.content_loss_factor)

                if args.fp16:
                    with apex.amp.scale_loss(
                            generator_loss,
                            self.optimizer_generator) as scaled_loss:
                        scaled_loss.backward()
                else:
                    generator_loss.backward()

                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                score_real = self.discriminator(high_resolution)
                # Detach so the discriminator update does not backprop into
                # the generator.
                score_fake = self.discriminator(fake_high_resolution.detach())
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr +
                                      adversarial_loss_rf) / 2

                if args.fp16:
                    with apex.amp.scale_loss(
                            discriminator_loss,
                            self.optimizer_discriminator) as scaled_loss:
                        scaled_loss.backward()
                else:
                    discriminator_loss.backward()

                self.optimizer_discriminator.step()

                # One training step equals n_unit_scheduler_step "unit" steps.
                for _ in range(self.n_unit_scheduler_step):
                    self.lr_scheduler_generator.step()
                    self.lr_scheduler_discriminator.step()
                    self.unit_scheduler_step += 1

                score = self.metric(fake_high_resolution.detach(),
                                    high_resolution)
                print(
                    f'\r'
                    f"{epoch:>3d}:{args.num_epoch:<3d}"
                    f"{step:>5d}:{total_step:<4d}"
                    f"{discriminator_loss.item():>10.4f}"
                    f"{generator_loss.item():>10.4f}"
                    f"{adversarial_loss.item()*self.adversarial_loss_factor:>10.4f}"
                    f"{perceptual_loss.item()*self.perceptual_loss_factor:>10.4f}"
                    f"{content_loss.item()*self.content_loss_factor:>10.4f}"
                    f"{score.item():>10.4f}",
                    end='')

                if step % 5000 == 0:
                    result = torch.cat(
                        (high_resolution, fake_high_resolution), 2)
                    save_image(result, sample_dir_epoch / f"SR_{step}.png")

            # History holds the values from the epoch's last batch.
            score_value = score.item()
            self.history['adversarial_loss'].append(
                adversarial_loss.item() * self.adversarial_loss_factor)
            self.history['discriminator_loss'].append(
                discriminator_loss.item())
            self.history['perceptual_loss'].append(perceptual_loss.item() *
                                                   self.perceptual_loss_factor)
            self.history['content_loss'].append(content_loss.item() *
                                                self.content_loss_factor)
            self.history['generator_loss'].append(generator_loss.item())
            self.history['score'].append(score_value)

            self.save(epoch, 'last.pth')
            # Store a plain float rather than keeping a CUDA tensor alive.
            if score_value > self.best_score:
                self.best_score = score_value
                self.save(epoch, 'best.pth')

    def build_model(self, args):
        """Create the ESRGAN generator (3->3 channels, 64 features) and the discriminator on CUDA."""
        self.generator = ESRGAN(3, 3, 64,
                                scale_factor=args.scale_factor).cuda()
        self.discriminator = Discriminator().cuda()

    def build_optimizer(self, args):
        """Create one Adam optimizer per network with identical hyper-parameters."""
        self.optimizer_generator = Adam(self.generator.parameters(),
                                        lr=self.lr,
                                        betas=(args.b1, args.b2),
                                        weight_decay=args.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(),
                                            lr=self.lr,
                                            betas=(args.b1, args.b2),
                                            weight_decay=args.weight_decay)

    def initialize_model_opt_fp16(self):
        """Wrap both networks and optimizers for apex mixed precision (O2).

        apex requires ``amp.initialize`` to be called exactly once, so all
        models and optimizers are passed together as lists.
        """
        models, optimizers = apex.amp.initialize(
            [self.generator, self.discriminator],
            [self.optimizer_generator, self.optimizer_discriminator],
            opt_level='O2')
        self.generator, self.discriminator = models
        self.optimizer_generator, self.optimizer_discriminator = optimizers

    def parallelize_model(self):
        """Wrap both networks in apex DistributedDataParallel."""
        self.generator = apex.parallel.DistributedDataParallel(
            self.generator, delay_allreduce=True)
        self.discriminator = apex.parallel.DistributedDataParallel(
            self.discriminator, delay_allreduce=True)

    def build_scheduler(self, args):
        """Create MultiStepLR schedulers (gamma 0.5 at ``self.decay_iter``) for both optimizers.

        One training step counts as ``(batch_size // 16) * nodes`` unit steps.
        When resuming, the schedulers continue from ``self.unit_scheduler_step``.
        """
        if not hasattr(self, 'unit_scheduler_step'):
            self.unit_scheduler_step = -1
        self.n_unit_scheduler_step = (args.batch_size // 16) * args.nodes
        print(f'Batch size: {args.batch_size}. '
              f'Number of nodes: {args.nodes}. '
              f'Each step here equates to {self.n_unit_scheduler_step} '
              f'unit scheduler step in the paper.\n'
              f'Current unit scheduler step: {self.unit_scheduler_step}.')
        last_epoch = self.unit_scheduler_step if args.resume else -1
        self.lr_scheduler_generator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_generator,
            milestones=self.decay_iter,
            gamma=.5,
            last_epoch=last_epoch)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_discriminator,
            milestones=self.decay_iter,
            gamma=.5,
            last_epoch=last_epoch)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with any leading ``'module.'`` removed from its keys."""
        prefix = 'module.'
        return {
            k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()
        }

    def _load_networks(self, cpt, args, source):
        """Load the generator/discriminator weights stored in checkpoint dict ``cpt``.

        Checkpoints always store ``module.``-prefixed keys (see ``save``), so
        the prefix is stripped for non-distributed (unwrapped) models.
        """
        g_sdict = cpt['g_state_dict']
        d_sdict = cpt['d_state_dict']
        if g_sdict is not None:
            if not args.distributed:
                g_sdict = self._strip_module_prefix(g_sdict)
            self.generator.load_state_dict(g_sdict)
            print(f'[*] Loading generator from {source}')
        if d_sdict is not None:
            if not args.distributed:
                d_sdict = self._strip_module_prefix(d_sdict)
            self.discriminator.load_state_dict(d_sdict)
            print(f'[*] Loading discriminator from {source}')

    def load_model(self, args):
        """Load network weights (and amp state if fp16) from ``args.load``.

        Optimizer state and training progress are not restored; ``resume``
        does that. If the file does not exist, a warning is printed and nothing
        is loaded.
        """
        path_to_load = Path(args.load)
        if not path_to_load.is_file():
            print(f'[!] No checkpoint found at {path_to_load}')
            return
        cpt = torch.load(path_to_load,
                         map_location=lambda storage, loc: storage.cuda())
        self._load_networks(cpt, args, path_to_load)
        if args.fp16 and cpt.get('amp') is not None:
            apex.amp.load_state_dict(cpt['amp'])

    def resume(self, args):
        """Restore a full training state from ``args.resume``.

        Restores the epoch counter (into ``args.epoch``), the scheduler step,
        the history, both networks, both optimizers, and the amp state if fp16.

        Raises:
            ValueError: if ``args.resume`` is not an existing file.
        """
        path_to_resume = Path(args.resume)
        if not path_to_resume.is_file():
            raise ValueError(
                f'[!] No checkpoint to resume from at {path_to_resume}')
        cpt = torch.load(path_to_resume,
                         map_location=lambda storage, loc: storage.cuda())
        if cpt['epoch'] is not None:
            args.epoch = cpt['epoch'] + 1
        if cpt['unit_scheduler_step'] is not None:
            self.unit_scheduler_step = cpt['unit_scheduler_step'] + 1
        if cpt['history'] is not None:
            self.history = cpt['history']
        self._load_networks(cpt, args, path_to_resume)
        optg_sdict = cpt['opt_g_state_dict']
        optd_sdict = cpt['opt_d_state_dict']
        if optg_sdict is not None:
            self.optimizer_generator.load_state_dict(optg_sdict)
            print(f'[*] Loading generator optmizer from {path_to_resume}')
        if optd_sdict is not None:
            self.optimizer_discriminator.load_state_dict(optd_sdict)
            print(f'[*] Loading discriminator optmizer '
                  f'from {path_to_resume}')
        if args.fp16 and cpt.get('amp') is not None:
            apex.amp.load_state_dict(cpt['amp'])

    def save(self, epoch, filename):
        """Write the full training state to ``<checkpoint_dir>/<filename>``.

        Network keys are always stored with a ``module.`` prefix, added here
        for non-distributed models, so checkpoints are interchangeable between
        distributed and single-process runs.
        """
        g_sdict = self.generator.state_dict()
        d_sdict = self.discriminator.state_dict()
        if not self.args.distributed:
            g_sdict = {f'module.{k}': v for k, v in g_sdict.items()}
            d_sdict = {f'module.{k}': v for k, v in d_sdict.items()}
        save_dict = {
            'epoch': epoch,
            'unit_scheduler_step': self.unit_scheduler_step,
            'history': self.history,
            'g_state_dict': g_sdict,
            'd_state_dict': d_sdict,
            'opt_g_state_dict': self.optimizer_generator.state_dict(),
            'opt_d_state_dict': self.optimizer_discriminator.state_dict(),
            'amp': apex.amp.state_dict() if self.args.fp16 else None,
            'args': self.args
        }
        torch.save(save_dict, Path(self.args.checkpoint_dir) / filename)
Пример #6
0
class NeuralNetworks(nn.Module, ObservableData):
    '''
    Neural Networks.

    A stack of fully-connected (`nn.Linear`) layers. The input is flattened to
    `(batch, features)` and every layer applies
    linear -> activation -> dropout -> optional batch normalization.
    Layers are built lazily by `initialize_params`, because the input width is
    only known once the first batch (or a checkpoint) is seen.

    References:
        - Kamyshanska, H., & Memisevic, R. (2014). The potential energy of an autoencoder. IEEE transactions on pattern analysis and machine intelligence, 37(6), 1261-1273.
    '''

    # `bool` that means initialization in this class will be deferred or not.
    # When True, `initialize_params` builds the layers but does not create
    # `self.optimizer` (presumably a wrapping model creates it - verify).
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss,
        initializer_f=None,
        optimizer_f=None,
        learning_rate=1e-05,
        units_list=None,
        dropout_rate_list=None,
        activation_list=None,
        hidden_batch_norm_list=None,
        ctx="cpu",
        regularizatable_data_list=None,
        scale=1.0,
        output_no_bias_flag=False,
        all_no_bias_flag=False,
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a `ComputableLoss` or `nn.modules.loss._Loss`.
            initializer_f:                  A function that contains `torch.nn.init`.
                                            This function receives a weight `tensor` and returns the initialized `tensor`.
                                            If `None`, weights are drawn from the Xavier normal distribution.

            optimizer_f:                    A function that creates a `torch.optim.optimizer.Optimizer` for the model.
                                            This function receives `self.parameters()` and returns the optimizer.
                                            If `None`, `Adam` with `learning_rate` is used.

            learning_rate:                  `float` of learning rate (used only by the default `Adam`).
            units_list:                     `list` of `int` of the number of units in hidden/output layers.
                                            If `None`, `[100, 1]`.
            dropout_rate_list:              `list` of `float` of dropout rate, one per layer.
                                            If `None`, `[0.0, 0.5]`.
            activation_list:                `list` with one activation per layer: a callable applied to the
                                            layer output, or one of the strings "identity", "identity_adjusted",
                                            "softmax", "log_softmax".
                                            If `None`, `[torch.nn.functional.tanh, torch.nn.functional.sigmoid]`.
            hidden_batch_norm_list:         `list` with one entry per layer: an `int` (width of a new
                                            `nn.BatchNorm1d`), an `nn.Module` used as-is, or `None` for no
                                            batch normalization. If `None`, `[100, None]`.
            ctx:                            Device passed to `self.to()` (e.g. "cpu").
            regularizatable_data_list:      `list` of `RegularizatableData`. If `None`, no regularization.
            scale:                          `float` of scaling factor for initial parameters.
                                            NOTE(review): not used by this class.
            output_no_bias_flag:            `bool` for using bias or not in output layer(last hidden layer).
            all_no_bias_flag:               `bool` for using bias or not in all layer.
            not_init_flag:                  `bool`; if True, `initialize_params` does not create the optimizer.

        Raises:
            TypeError:      If `computable_loss` or an element of `regularizatable_data_list` has the wrong type.
            ValueError:     If `units_list`, `activation_list` and `dropout_rate_list` differ in length.
        '''
        super(NeuralNetworks, self).__init__()

        # Build fresh default lists per call: these lists are stored on the
        # instance by reference, so shared mutable defaults would let a
        # mutation on one model leak into every later model.
        if units_list is None:
            units_list = [100, 1]
        if dropout_rate_list is None:
            dropout_rate_list = [0.0, 0.5]
        if activation_list is None:
            activation_list = [
                torch.nn.functional.tanh, torch.nn.functional.sigmoid
            ]
        if hidden_batch_norm_list is None:
            hidden_batch_norm_list = [100, None]
        if regularizatable_data_list is None:
            regularizatable_data_list = []

        if not isinstance(computable_loss,
                          (ComputableLoss, nn.modules.loss._Loss)):
            raise TypeError(
                "The type of `computable_loss` must be `ComputableLoss` or `nn.modules.loss._Loss`."
            )

        if len(units_list) != len(activation_list):
            raise ValueError(
                "The length of `units_list` and `activation_list` must be equivalent."
            )

        if len(dropout_rate_list) != len(units_list):
            raise ValueError(
                "The length of `dropout_rate_list` and `units_list` must be equivalent."
            )

        self.initializer_f = initializer_f
        self.optimizer_f = optimizer_f
        self.__units_list = units_list
        self.__all_no_bias_flag = all_no_bias_flag
        self.__output_no_bias_flag = output_no_bias_flag

        # One dropout module per layer (p=0.0 makes it a no-op).
        self.dropout_forward_list = nn.ModuleList(
            [nn.Dropout(p=rate) for rate in dropout_rate_list])

        # `int` -> new BatchNorm1d of that width; module -> used as-is;
        # `None` -> no batch normalization for that layer.
        batch_norms = []
        for batch_norm in hidden_batch_norm_list:
            if batch_norm is None:
                batch_norms.append(None)
            elif isinstance(batch_norm, int):
                batch_norms.append(nn.BatchNorm1d(batch_norm))
            else:
                batch_norms.append(batch_norm)
        self.hidden_batch_norm_list = nn.ModuleList(batch_norms)

        self.__not_init_flag = not_init_flag
        self.activation_list = activation_list

        self.__computable_loss = computable_loss
        self.__learning_rate = learning_rate

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__ctx = ctx

        # Filled with `nn.Linear` layers by `initialize_params`.
        self.fc_list = []
        self.flatten = nn.Flatten()

        self.epoch = 0
        self.__loss_list = []
        logger = getLogger("accelbrainbase")
        self.__logger = logger
        # Width of the flattened input; `None` until `initialize_params` runs.
        self.__input_dim = None

    def initialize_params(self, input_dim):
        '''
        Initialize params.

        Builds the linear layers input_dim -> units_list[0] -> ... ->
        units_list[-1], moves the module to `ctx`, and creates
        `self.optimizer` unless initialization is deferred or `not_init_flag`
        is set. Once `input_dim` has been recorded, later calls are no-ops.

        Args:
            input_dim:      The number of units in input layer.
        '''
        if self.__input_dim is not None:
            return
        self.__input_dim = input_dim

        # Layers already present (e.g. set by a subclass/caller) are kept.
        if len(self.fc_list) > 0:
            return

        layer_count = len(self.__units_list)
        in_dims = [input_dim] + list(self.__units_list[:-1])
        fc_list = []
        for i, (in_dim, out_dim) in enumerate(zip(in_dims, self.__units_list)):
            # The output-layer rule also applies when the network has a single
            # layer (previously the first layer ignored `output_no_bias_flag`).
            is_output_layer = i == layer_count - 1
            use_bias = not (self.__all_no_bias_flag is True or
                            (self.__output_no_bias_flag is True and
                             is_output_layer))

            fc = nn.Linear(in_dim, out_dim, bias=use_bias)
            if self.initializer_f is None:
                fc.weight = torch.nn.init.xavier_normal_(fc.weight, gain=1.0)
            else:
                fc.weight = self.initializer_f(fc.weight)
            fc_list.append(fc)

        self.fc_list = nn.ModuleList(fc_list)
        self.to(self.__ctx)

        if self.init_deferred_flag is False and self.__not_init_flag is False:
            if self.optimizer_f is None:
                self.optimizer = Adam(
                    self.parameters(),
                    lr=self.__learning_rate,
                )
            else:
                self.optimizer = self.optimizer_f(self.parameters())

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Each training batch gets one optimizer step on `compute_loss`,
        followed by `regularize()`. Every `iter_n / epochs` iterations the
        paired test batch is scored and both losses are appended to
        `loss_arr`. A KeyboardInterrupt stops training early but keeps the
        progress made so far.

        Args:
            iteratable_data:     is-a `IteratableData`.

        Raises:
            TypeError:  If `iteratable_data` is not an `IteratableData`.
        '''
        if isinstance(iteratable_data, IteratableData) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        # Record losses once per "epoch" (every iter_n / epochs iterations).
        # Clamp to 1 so that iter_n < epochs no longer raises
        # ZeroDivisionError in the modulo below.
        eval_interval = max(
            1, int(iteratable_data.iter_n / iteratable_data.epochs))
        epoch = self.epoch
        iter_n = 0
        try:
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.initialize_params(
                    input_dim=self.flatten(batch_observed_arr).shape[-1])
                self.optimizer.zero_grad()
                pred_arr = self.inference(batch_observed_arr)
                loss = self.compute_loss(pred_arr, batch_target_arr)
                loss.backward()
                self.optimizer.step()
                self.regularize()

                if (iter_n + 1) % eval_interval == 0:
                    # Score the test batch in eval mode so dropout is off and
                    # batch-norm running statistics are not updated with test
                    # data; restore the previous train/eval mode afterwards.
                    was_training = self.training
                    self.eval()
                    try:
                        with torch.inference_mode():
                            test_pred_arr = self.inference(
                                test_batch_observed_arr)
                            test_loss = self.compute_loss(
                                test_pred_arr, test_batch_target_arr)
                    finally:
                        self.train(was_training)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.__logger.debug("Epochs: " + str(epoch + 1) +
                                        " Train loss: " + str(_loss) +
                                        " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.__logger.debug("Interrupt.")

        self.epoch = epoch
        self.__logger.debug("end. ")

    def inference(self, observed_arr):
        '''
        Inference samples drawn by `IteratableData.generate_inferenced_samples()`.

        Runs in whatever train/eval mode the module is currently in.

        Args:
            observed_arr:   rank-2 Array like or sparse matrix as the observed data points.
                            The shape is: (batch size, feature points)

        Returns:
            `tensor` of inferenced feature points.
        '''
        return self(observed_arr)

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `tensor` of predictions.
            labeled_arr:    `tensor` of target labels.

        Returns:
            loss returned by `computable_loss`.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Regularization.

        Passes the current parameters through every `RegularizatableData` in
        turn and writes the result back (entries the regularizers do not
        return are left untouched because the load is non-strict).
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            # One non-strict load applies every entry at once instead of one
            # full-module traversal per key.
            self.load_state_dict(params_dict, strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters (a plain copy of `self.state_dict()`).
        '''
        # Build the state dict once rather than once per key.
        return dict(self.state_dict())

    def forward(self, x):
        '''
        Forward with torch.

        Args:
            x:      `tensor` of observed data points; flattened to
                    (batch size, features) before the first layer.

        Returns:
            `tensor` of inferenced feature points.
        '''
        x = self.flatten(x)
        self.initialize_params(input_dim=x.shape[-1])
        for i, activation in enumerate(self.activation_list):
            x = self.fc_list[i](x)

            if activation == "identity_adjusted":
                # Divide by the element count; `numel()` avoids summing a
                # ones tensor, which can overflow in half precision.
                x = x / x.numel()
            elif activation == "softmax":
                # Explicit feature axis; for 2-D input this is the axis the
                # deprecated implicit-dim call already chose.
                x = F.softmax(x, dim=-1)
            elif activation == "log_softmax":
                x = F.log_softmax(x, dim=-1)
            elif activation != "identity":
                x = activation(x)

            if self.dropout_forward_list[i] is not None:
                x = self.dropout_forward_list[i](x)
            if self.hidden_batch_norm_list[i] is not None:
                x = self.hidden_batch_norm_list[i](x)

        return x

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Stores the epoch counter, model and optimizer state dicts, the loss
        history (`loss_arr`, a numpy array) and the input width.

        Args:
            filename:       File name.
        '''
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
                'input_dim': self.__input_dim,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from files.

        Args:
            filename:       File name.
            ctx:            Device to move the model to. When given, tensors
                            are also loaded directly onto this device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        if ctx is None:
            checkpoint = torch.load(filename)
        else:
            checkpoint = torch.load(filename, map_location=ctx)
        self.initialize_params(input_dim=checkpoint["input_dim"])
        if ctx is not None:
            # Move before loading optimizer state: `Optimizer.load_state_dict`
            # places state tensors on the device of their parameters, so
            # moving afterwards would leave e.g. Adam moments on the old device.
            self.to(ctx)
            self.__ctx = ctx
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_loss_arr(self):
        ''' getter for losses: rows of (train loss, test loss). '''
        return np.array(self.__loss_list)

    loss_arr = property(get_loss_arr, set_readonly)

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not. '''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    def get_units_list(self):
        ''' getter for `list` of units in each layer. '''
        return self.__units_list

    units_list = property(get_units_list, set_readonly)
Пример #7
0
class Trainer:
    def __init__(self, config, data_loader_train, data_loader_val, device):
        self.device = device
        self.num_epoch = config["num_epoch"]
        self.start_epoch = config["start_epoch"]
        self.image_size = config["image_size"]
        self.sample_dir = config["sample_dir"]

        self.batch_size = config["batch_size"]
        self.data_loader_train = data_loader_train
        self.data_loader_val = data_loader_val
        self.num_res_blocks = config["num_rrdn_blocks"]
        self.nf = config["nf"]
        self.scale_factor = config["scale_factor"]
        self.is_psnr_oriented = config["is_psnr_oriented"]
        self.load_previous_opt = config["load_previous_opt"]

        if self.is_psnr_oriented:
            self.lr = config["p_lr"]
            self.content_loss_factor = config["p_content_loss_factor"]
            self.perceptual_loss_factor = config["p_perceptual_loss_factor"]
            self.adversarial_loss_factor = config["p_adversarial_loss_factor"]
            self.decay_iter = config["p_decay_iter"]
        else:
            self.lr = config["g_lr"]
            self.content_loss_factor = config["g_content_loss_factor"]
            self.perceptual_loss_factor = config["g_perceptual_loss_factor"]
            self.adversarial_loss_factor = config["g_adversarial_loss_factor"]
            self.decay_iter = config["g_decay_iter"]

        self.metrics = {
            "dis_loss": [],
            "gen_loss": [],
            "per_loss": [],
            "con_loss": [],
            "adv_loss": [],
            "SSIM": [],  # validation set per epoch
            "PSNR": [],  # validation set per epoch
        }

        self.build_model(config)
        self.lr_scheduler_generator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_generator, self.decay_iter)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_discriminator, self.decay_iter)

    def train(self):
        os.makedirs("/content/drive/MyDrive/Project-ESRGAN", exist_ok=True)

        total_step = len(self.data_loader_train)
        adversarial_criterion = nn.BCEWithLogitsLoss().to(self.device)
        content_criterion = nn.L1Loss().to(self.device)
        perception_criterion = PerceptualLoss().to(self.device)

        Tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
        ) else torch.Tensor

        # FID score
        FID_DIM = 768
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FID_DIM]
        fid_model = InceptionV3([block_idx])
        fid_model.to(self.device)

        best_val_psnr = 0.0
        best_val_ssim = 0.0
        best_val_fid = np.inf

        # FOR BICUBIC
        already_calculated_bicubic = False
        upsampler = torch.nn.Upsample(scale_factor=4, mode="bicubic")

        for epoch in range(self.start_epoch,
                           self.start_epoch + self.num_epoch):
            self.generator.train()
            self.discriminator.train()

            epoch_gen_loss = []
            epoch_dis_loss = []
            epoch_per_loss = []
            epoch_adv_loss = []
            epoch_con_loss = []

            SAVE = False

            training_loader_iter = iter(self.data_loader_train)
            length_train = len(training_loader_iter)

            if not os.path.exists(os.path.join(self.sample_dir, str(epoch))):
                os.makedirs(os.path.join(self.sample_dir, str(epoch)))

            for step in tqdm(
                    range(length_train),
                    desc=
                    f"Epoch: {epoch}/{self.start_epoch + self.num_epoch-1}",
            ):
                image = next(training_loader_iter)
                # print("step", step)
                low_resolution = image["lr"].to(self.device)
                high_resolution = image["hr"].to(self.device)

                # Adversarial ground truths
                real_labels = Variable(
                    Tensor(
                        np.ones((low_resolution.size(0),
                                 *self.discriminator.output_shape))),
                    requires_grad=False,
                )
                fake_labels = Variable(
                    Tensor(
                        np.zeros((low_resolution.size(0),
                                  *self.discriminator.output_shape))),
                    requires_grad=False,
                )

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()

                with amp.autocast():
                    fake_high_resolution = self.generator(low_resolution)

                    # Content loss - L1 loss - psnr oriented
                    content_loss = content_criterion(fake_high_resolution,
                                                     high_resolution)

                    if not self.is_psnr_oriented:
                        for p in self.discriminator.parameters():
                            p.requires_grad = False

                        # Extract validity predictions from discriminator
                        score_real = self.discriminator(
                            high_resolution).detach()
                        score_fake = self.discriminator(fake_high_resolution)

                        # ----------------------
                        # calculate Realtivistic GAN loss Drf and Dfr
                        discriminator_rf = score_real - score_fake.mean(
                            0, keepdim=True)
                        discriminator_fr = score_fake - score_real.mean(
                            0, keepdim=True)

                        adversarial_loss_rf = adversarial_criterion(
                            discriminator_rf, fake_labels)
                        adversarial_loss_fr = adversarial_criterion(
                            discriminator_fr, real_labels)
                        adversarial_loss = (adversarial_loss_fr +
                                            adversarial_loss_rf) / 2
                        # ----------------------

                        # Perceptual loss - VGG loss before activations
                        perceptual_loss = perception_criterion(
                            high_resolution, fake_high_resolution)

                        generator_loss = (
                            perceptual_loss * self.perceptual_loss_factor +
                            adversarial_loss * self.adversarial_loss_factor +
                            content_loss * self.content_loss_factor)

                    else:
                        generator_loss = content_loss * self.content_loss_factor

                self.scaler_gen.scale(generator_loss).backward()
                self.scaler_gen.step(self.optimizer_generator)

                scale_gen = self.scaler_gen.get_scale()
                self.scaler_gen.update()
                skip_gen_lr_sched = scale_gen != self.scaler_gen.get_scale()
                # self.optimizer_generator.step()

                self.metrics["gen_loss"].append(
                    np.round(generator_loss.detach().item(), 5))
                self.metrics["con_loss"].append(
                    np.round(
                        content_loss.detach().item() *
                        self.content_loss_factor, 4))

                epoch_gen_loss.append(self.metrics["gen_loss"][-1])
                epoch_con_loss.append(self.metrics["con_loss"][-1])

                ##########################
                # training discriminator #
                ##########################
                if not self.is_psnr_oriented:
                    self.optimizer_discriminator.zero_grad()
                    for p in self.discriminator.parameters():
                        p.requires_grad = True

                    with amp.autocast():
                        score_real = self.discriminator(high_resolution)

                        score_fake = self.discriminator(
                            fake_high_resolution).detach()

                        # real
                        discriminator_rf = score_real - score_fake.mean(
                            axis=0, keepdim=True)
                        adversarial_loss_rf = (adversarial_criterion(
                            discriminator_rf, real_labels) * 0.5)

                        # fake
                        score_fake = self.discriminator(
                            fake_high_resolution.detach())

                        discriminator_fr = score_fake - score_real.mean(
                            axis=0, keepdim=True)
                        adversarial_loss_fr = (adversarial_criterion(
                            discriminator_fr, fake_labels) * 0.5)

                        # score_real = self.discriminator(high_resolution)
                        # score_fake = self.discriminator(fake_high_resolution.detach())
                        # discriminator_rf = score_real - score_fake.mean(
                        #     axis=0, keepdim=True
                        # )

                        # adversarial_loss_rf = adversarial_criterion(
                        #     discriminator_rf, real_labels
                        # )

                        # discriminator_loss = (
                        #     adversarial_loss_fr + adversarial_loss_rf
                        # ) / 2

                    self.scaler_dis.scale(adversarial_loss_rf).backward(
                        retain_graph=True)
                    self.scaler_dis.scale(adversarial_loss_fr).backward()
                    self.scaler_dis.step(self.optimizer_discriminator)
                    dis_scale_val = self.scaler_dis.get_scale()
                    self.scaler_dis.update()
                    skip_dis_lr_sched = dis_scale_val != self.scaler_dis.get_scale(
                    )

                    discriminator_loss = (adversarial_loss_rf.detach().item() +
                                          adversarial_loss_fr.detach().item())

                    self.metrics["dis_loss"].append(
                        np.round(discriminator_loss, 5))

                    # generator metrics
                    self.metrics["adv_loss"].append(
                        np.round(
                            adversarial_loss.detach().item() *
                            self.adversarial_loss_factor,
                            4,
                        ))
                    self.metrics["per_loss"].append(
                        np.round(
                            perceptual_loss.detach().item() *
                            self.perceptual_loss_factor,
                            4,
                        ))

                    epoch_dis_loss.append(self.metrics["dis_loss"][-1])
                    epoch_adv_loss.append(self.metrics["adv_loss"][-1])
                    epoch_per_loss.append(self.metrics["per_loss"][-1])

                torch.cuda.empty_cache()
                gc.collect()

                if step == int(total_step /
                               2) or step == 0 or step == (total_step - 1):
                    if not self.is_psnr_oriented:
                        print(
                            f"[Epoch {epoch}/{self.start_epoch+self.num_epoch-1}] [Batch {step+1}/{total_step}]"
                            f"[D loss {self.metrics['dis_loss'][-1]}] [G loss {self.metrics['gen_loss'][-1]}]"
                            f"[perceptual loss {self.metrics['per_loss'][-1]}]"
                            f"[adversarial loss {self.metrics['adv_loss'][-1]}]"
                            f"[content loss {self.metrics['con_loss'][-1]}]"
                            f"")
                    else:
                        print(
                            f"[Epoch {epoch}/{self.start_epoch+self.num_epoch-1}] [Batch {step+1}/{total_step}] "
                            f"[content loss {self.metrics['con_loss'][-1]}]")

                    result = torch.cat(
                        (
                            denormalize(high_resolution.detach().cpu()),
                            denormalize(
                                upsampler(low_resolution).detach().cpu()),
                            denormalize(fake_high_resolution.detach().cpu()),
                        ),
                        2,
                    )

                    # print(result[0][:, 512:, :].min(), result[0][:, 512:, :].max())

                    save_image(
                        result,
                        os.path.join(self.sample_dir, str(epoch),
                                     f"ESR_{step+1}.png"),
                        nrow=8,
                        normalize=False,
                    )
                    wandb.log({
                        f"training_images_ESR_{step+1}":
                        wandb.Image(
                            os.path.join(self.sample_dir, str(epoch),
                                         f"ESR_{step+1}.png"))
                    })

                torch.cuda.empty_cache()
                gc.collect()

            # epoch metrics
            if not self.is_psnr_oriented:
                print(
                    f"Epoch: {epoch} -> Dis loss: {np.round(np.array(epoch_dis_loss).mean(), 4)} "
                    f"Gen loss: {np.round(np.array(epoch_gen_loss).mean(), 4)} "
                    f"Per loss:: {np.round(np.array(epoch_per_loss).mean(), 4)} "
                    f"Adv loss:: {np.round(np.array(epoch_adv_loss).mean(), 4)} "
                    f"Con loss:: {np.round(np.array(epoch_con_loss).mean(), 4)}"
                    f"")
                wandb.log({
                    "epoch":
                    epoch,
                    "Dis_loss":
                    np.round(np.array(epoch_dis_loss).mean(), 4),
                    "Gen_loss":
                    np.round(np.array(epoch_gen_loss).mean(), 4),
                    "Con_loss":
                    np.round(np.array(epoch_con_loss).mean(), 4),
                    "Per_loss":
                    np.round(np.array(epoch_per_loss).mean(), 4),
                    "Adv_loss":
                    np.round(np.array(epoch_adv_loss).mean(), 4),
                    "Con_loss":
                    np.round(np.array(epoch_con_loss).mean(), 4),
                })
            else:
                print(
                    f"Epoch: {epoch} -> "
                    f"Gen loss: {np.round(np.array(epoch_gen_loss).mean(), 4)} "
                    f"Con loss:: {np.round(np.array(epoch_con_loss).mean(), 4)}"
                    f"")
                wandb.log({
                    "epoch":
                    epoch,
                    "Gen_loss":
                    np.round(np.array(epoch_gen_loss).mean(), 4),
                    "Con_loss":
                    np.round(np.array(epoch_con_loss).mean(), 4),
                })

            if not skip_gen_lr_sched:
                self.lr_scheduler_generator.step()

            if not self.is_psnr_oriented:
                if not skip_dis_lr_sched:
                    self.lr_scheduler_discriminator.step()

            # validation set SSIM and PSNR
            val_batch_psnr = []
            val_batch_ssim = []
            val_batch_FID = []

            if not already_calculated_bicubic:
                ups_batch_psnr = []
                ups_batch_ssim = []
                ups_batch_FID = []

            for idx, image_val in enumerate(self.data_loader_val):
                val_low_resolution = image_val["lr"].to(self.device)
                val_high_resolution = image_val["hr"].to(self.device)

                self.generator.eval()

                with torch.no_grad():
                    with amp.autocast():
                        val_fake_high_res = self.generator(
                            val_low_resolution).detach()

                        # generated image metrics FID, PSNR, SSIM
                        val_fid = cal_fretchet(
                            val_high_resolution.detach(),
                            val_fake_high_res,
                            fid_model,
                            dims=FID_DIM,
                        )

                    val_psnr, val_ssim = cal_img_metrics(
                        val_fake_high_res,
                        val_high_resolution,
                    )

                    val_batch_psnr.append(val_psnr)
                    val_batch_ssim.append(val_ssim)
                    val_batch_FID.append(val_fid)

                    # bicubic image metrics FID, PSNR, SSIM
                    if not already_calculated_bicubic:
                        with amp.autocast():
                            ups_fid = cal_fretchet(
                                val_high_resolution,
                                val_low_resolution.detach(),
                                fid_model,
                                FID_DIM,
                            )

                        ups_psnr, ups_ssim = cal_img_metrics(
                            upsampler(val_low_resolution.cpu()),
                            val_high_resolution.cpu(),
                        )

                        ups_batch_psnr.append(ups_psnr)
                        ups_batch_ssim.append(ups_ssim)
                        ups_batch_FID.append(ups_fid)

                # visualization
                result_val = torch.cat(
                    (
                        denormalize(val_high_resolution.detach().cpu()),
                        denormalize(
                            upsampler(val_low_resolution).detach().cpu()),
                        denormalize(val_fake_high_res.detach().cpu()),
                    ),
                    2,
                )
                save_image(
                    result_val,
                    os.path.join(self.sample_dir,
                                 f"Validation_{epoch}_{idx}.png"),
                    nrow=8,
                    normalize=False,
                )

            val_epoch_psnr = round(
                sum(val_batch_psnr) / len(val_batch_psnr), 4)
            val_epoch_ssim = round(
                sum(val_batch_ssim) / len(val_batch_ssim), 4)
            val_epoch_fid = round(sum(val_batch_FID) / len(val_batch_FID), 4)

            if not already_calculated_bicubic:
                ups_epoch_psnr = round(
                    sum(ups_batch_psnr) / len(ups_batch_psnr), 4)
                ups_epoch_ssim = round(
                    sum(ups_batch_ssim) / len(ups_batch_ssim), 4)
                ups_epoch_FID = round(
                    sum(ups_batch_FID) / len(ups_batch_FID), 4)
                already_calculated_bicubic = True

            # log validation psnr, ssim
            wandb.log({
                "epoch": epoch,
                "valid_psnr": val_epoch_psnr,
                "valid_ssim": val_epoch_ssim,
                "valid_fid": val_epoch_fid,
            })
            # log validation image 0
            wandb.log({
                "validation_images_1":
                wandb.Image(
                    os.path.join(self.sample_dir,
                                 f"Validation_{epoch}_0.png")),
                "validation_images_2":
                wandb.Image(
                    os.path.join(self.sample_dir,
                                 f"Validation_{epoch}_1.png")),
                "validation_images_3":
                wandb.Image(
                    os.path.join(self.sample_dir, f"Validation_{epoch}_2.png"))
            })

            if val_epoch_fid < best_val_fid:
                best_val_fid = val_epoch_fid
                SAVE = True

            if val_epoch_psnr > best_val_psnr:
                best_val_psnr = val_epoch_psnr
                SAVE = True

            if val_epoch_ssim > best_val_ssim:
                best_val_ssim = val_epoch_ssim
                SAVE = True

            self.metrics["PSNR"].append(val_epoch_psnr)
            self.metrics["SSIM"].append(val_epoch_ssim)

            print(
                f"Validation Set: PSNR: {val_epoch_psnr}, SSIM: {val_epoch_ssim}, FID: {val_epoch_fid}"
            )
            print(
                f"Bicubic Ups: PSNR: {ups_epoch_psnr}, SSIM: {ups_epoch_ssim}, FID: {ups_epoch_FID}"
            )

            del (
                val_fid,
                val_psnr,
                val_ssim,
                val_epoch_psnr,
                val_epoch_ssim,
                val_epoch_fid,
                val_low_resolution,
                val_fake_high_res,
                val_high_resolution,
            )

            torch.cuda.empty_cache()
            gc.collect()

            models_dict = {
                "next_epoch": epoch + 1,
                f"generator_dict": self.generator.state_dict(),
                f"optim_gen": self.optimizer_generator.state_dict(),
                f"grad_scaler_gen": self.scaler_gen.state_dict(),
                f"metrics": self.metrics,
            }

            if not self.is_psnr_oriented:
                models_dict[
                    f"discriminator_dict"] = self.discriminator.state_dict()
                models_dict[
                    f"optim_dis"] = self.optimizer_discriminator.state_dict()
                models_dict[f"grad_scaler_dis"] = self.scaler_dis.state_dict()

            save_name = f"checkpoint_{epoch}.tar"

            if SAVE:
                # remove all previous best checkpoints
                _ = [
                    os.remove(os.path.join(r"/content/", file))
                    for file in os.listdir(r"/content/")
                    if file.startswith("best_")
                ]

                _ = [
                    os.remove(
                        os.path.join(r"/content/drive/MyDrive/Project-ESRGAN",
                                     file)) for file in
                    os.listdir(r"/content/drive/MyDrive/Project-ESRGAN")
                    if file.startswith("best_")
                ]
                save_name = f"best_checkpoint_{epoch}.tar"
                print(
                    f"Best val scores  till epoch {epoch} -> PSNR: {best_val_psnr}, SSIM: {best_val_ssim}, , FID: {best_val_fid}"
                )

            if save_name.startswith("checkpoint"):
                _ = [
                    os.remove(os.path.join(r"/content/", file))
                    for file in os.listdir(r"/content/")
                    if file.startswith("checkpoint")
                ]
                _ = [
                    os.remove(
                        os.path.join(r"/content/drive/MyDrive/Project-ESRGAN",
                                     file)) for file in
                    os.listdir(r"/content/drive/MyDrive/Project-ESRGAN")
                    if file.startswith("checkpoint")
                ]

            torch.save(models_dict, save_name)
            shutil.copyfile(
                save_name,
                os.path.join(r"/content/drive/MyDrive/Project-ESRGAN",
                             save_name),
            )

            torch.cuda.empty_cache()
            gc.collect()

        wandb.run.summary["valid_psnr"] = best_val_psnr
        wandb.run.summary["valid_ssim"] = best_val_ssim
        wandb.run.summary["valid_fid"] = best_val_fid

        return self.metrics

    def build_model(self, config):
        """Create generator/discriminator, their Adam optimizers and AMP
        grad scalers, then attempt to resume via ``self.load_model()``.

        Args:
            config: mapping that must provide ``"b1"`` and ``"b2"`` (Adam
                ``betas``) and ``"weight_decay"``; both optimizers share them.
        """
        gen_net = Generator(
            channels=3,
            nf=self.nf,
            num_res_blocks=self.num_res_blocks,
            scale=self.scale_factor,
        )
        self.generator = gen_net.to(self.device)
        # Custom weight initialisation provided by the Generator class.
        self.generator._mrsa_init(self.generator.layers_)

        disc_input_shape = (3, self.image_size, self.image_size)
        disc_net = Discriminator(input_shape=disc_input_shape)
        self.discriminator = disc_net.to(self.device)

        shared_betas = (config["b1"], config["b2"])
        shared_decay = config["weight_decay"]

        self.optimizer_generator = Adam(self.generator.parameters(),
                                        lr=self.lr,
                                        betas=shared_betas,
                                        weight_decay=shared_decay)
        # The discriminator uses a fixed learning rate, not self.lr.
        self.optimizer_discriminator = Adam(self.discriminator.parameters(),
                                            lr=0.0004,
                                            betas=shared_betas,
                                            weight_decay=shared_decay)

        # One GradScaler per network for mixed-precision training.
        self.scaler_gen = torch.cuda.amp.GradScaler()
        self.scaler_dis = torch.cuda.amp.GradScaler()

        self.load_model()

    def load_model(self):
        """Resume training state from checkpoint ``start_epoch - 1`` if present.

        Looks for ``checkpoint_{start_epoch-1}.tar`` in the Drive project
        folder and returns without changing anything when it does not exist.
        Otherwise it restores:

        * generator weights (always);
        * when ``self.load_previous_opt`` is true, the generator optimizer and
          grad scaler, and, best-effort, the discriminator weights, optimizer
          and grad scaler;
        * the metric history, ``self.start_epoch`` (from ``"next_epoch"``);
        * ``self.decay_iter`` shifted to be relative to the resumed epoch.
        """
        drive_path = r"/content/drive/MyDrive/Project-ESRGAN"
        print(f"[*] Finding checkpoint {self.start_epoch-1} in {drive_path}")

        checkpoint_file = f"checkpoint_{self.start_epoch-1}.tar"
        checkpoint_path = os.path.join(drive_path, checkpoint_file)
        if not os.path.exists(checkpoint_path):
            print(f"[!] No checkpoint for epoch {self.start_epoch -1}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.generator.load_state_dict(checkpoint["generator_dict"])
        print("Generator weights loaded.")

        if self.load_previous_opt:
            self.optimizer_generator.load_state_dict(checkpoint["optim_gen"])
            print("Generator Optimizer state loaded")

            self.scaler_gen.load_state_dict(checkpoint["grad_scaler_gen"])
            print("Grad Scaler - Generator loaded")

            # Discriminator entries are only written for non-PSNR-oriented
            # runs, so they may be missing (KeyError) or not match the current
            # model/optimizer (RuntimeError / ValueError). Loading them is
            # best-effort, but unlike a bare `except:` this neither hides
            # KeyboardInterrupt/SystemExit nor fails silently.
            try:
                self.discriminator.load_state_dict(
                    checkpoint["discriminator_dict"])
                print("Discriminator weights loaded.")
                self.optimizer_discriminator.load_state_dict(
                    checkpoint["optim_dis"])
                print("Discriminator optimizer loaded.")

                self.scaler_dis.load_state_dict(checkpoint["grad_scaler_dis"])
                print("Grad Scaler - Discriminator loaded")
            except (KeyError, RuntimeError, ValueError) as err:
                print(f"[!] Discriminator state not (fully) restored: {err!r}")

        saved_metrics = checkpoint["metrics"]
        for key in ("dis_loss", "gen_loss", "per_loss", "con_loss",
                    "adv_loss", "PSNR", "SSIM"):
            self.metrics[key] = saved_metrics[key]
        self.start_epoch = checkpoint["next_epoch"]

        # Shift decay_iter entries by the resumed epoch and keep only those
        # still in the future; fall back to [200] when none remain.
        remaining = []
        if self.decay_iter:
            shifted = np.array(self.decay_iter) - self.start_epoch
            # tolist() yields native Python numbers instead of numpy scalars.
            remaining = shifted[shifted > 0].tolist()

        self.decay_iter = remaining or [200]
        print("Decay_iter:", self.decay_iter)

        print(f"Checkpoint: {self.start_epoch-1} loaded")
Пример #8
0
class Train(object):
    """Trainer for the encoder / decoder / reduce_state ``Model``.

    Owns the vocabulary, the train/eval ``Batcher``s and the summary writers;
    ``setup_train`` adds the model, the Adam optimizer and a ``Checkpoint``
    helper. The per-step loss is teacher-forced negative log-likelihood of the
    target token, plus a coverage penalty when ``config.is_coverage`` is set.
    """

    def __init__(self, train_dir=None, eval_dir=None, vocab=None, vectors=None):
        """Create batchers, log directories and summary writers.

        Args:
            train_dir: directory for training summaries and the ``model/``
                sub-directory; defaults to ``<config.log_root>/train_<time>``.
            eval_dir: directory for evaluation summaries; defaults to
                ``<config.log_root>/eval_<time>``.
            vocab: prebuilt ``Vocab``; built from ``config`` when None.
            vectors: forwarded to ``Model`` by ``setup_train``.
        """
        self.vectors = vectors
        if vocab is None:
            self.vocab = Vocab(config.vocab_path, config.vocab_size)
        else:
            self.vocab = vocab

        print(self.vocab)
        self.batcher_train = Batcher(config.train_data_path, self.vocab, mode='train',
                                     batch_size=config.batch_size, single_pass=False)
        # NOTE(review): the sleeps presumably give each Batcher's background
        # threads time to start filling their queues — confirm.
        time.sleep(15)
        self.batcher_eval = Batcher(config.eval_data_path, self.vocab, mode='eval',
                                    batch_size=config.batch_size, single_pass=True)
        time.sleep(15)

        cur_time = int(time.time())
        if train_dir is None:
            train_dir = os.path.join(config.log_root, 'train_%d' % cur_time)
        if eval_dir is None:
            eval_dir = os.path.join(config.log_root, 'eval_%d' % cur_time)

        # makedirs also creates missing parents (e.g. config.log_root, or a
        # caller-supplied train_dir that does not exist yet) where os.mkdir
        # would raise FileNotFoundError.
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(eval_dir, exist_ok=True)

        self.model_dir = os.path.join(train_dir, 'model')
        os.makedirs(self.model_dir, exist_ok=True)

        self.summary_writer_train = writer.FileWriter(train_dir)
        self.summary_writer_eval = writer.FileWriter(eval_dir)

    def setup_train(self, model_file_path=None):
        """Build the model and optimizer, optionally resuming from a file.

        Args:
            model_file_path: checkpoint to resume from, or None for a fresh run.

        Returns:
            ``(start_iter, start_training_loss, start_eval_loss)``; all zero
            for a fresh run.
        """
        self.model = Model(model_file_path, vectors=self.vectors)

        params = list(self.model.encoder.parameters()) + list(self.model.decoder.parameters()) + \
                 list(self.model.reduce_state.parameters())

        pytorch_total_params = sum(p.numel() for p in params if p.requires_grad)
        print(f"Parameters count: {pytorch_total_params}")

        initial_lr = config.lr_coverage if config.is_coverage else config.lr
        self.optimizer = Adam(params, lr=initial_lr)
        start_iter, start_training_loss, start_eval_loss = 0, 0, 0

        if model_file_path is not None:
            # Load every tensor onto the CPU; optimizer state is moved to the
            # GPU below when CUDA is in use.
            checkpoint_state = torch.load(model_file_path,
                                          map_location=lambda storage, location: storage)
            start_iter = checkpoint_state['iter']
            start_training_loss = checkpoint_state['current_train_loss']
            start_eval_loss = checkpoint_state['current_eval_loss']

            # The optimizer state is restored only without coverage; with
            # coverage the freshly built optimizer (lr_coverage) is kept.
            if not config.is_coverage:
                self.optimizer.load_state_dict(checkpoint_state['optimizer'])
                if use_cuda:
                    for param_state in self.optimizer.state.values():
                        for k, v in param_state.items():
                            if isinstance(v, torch.Tensor):
                                param_state[k] = v.cuda()

        # Attribute name kept as-is (sic) for compatibility with callers.
        self.chechpoint = Checkpoint(self.model,
                                     self.optimizer,
                                     self.model_dir,
                                     start_eval_loss if start_eval_loss != 0 else float("inf"))

        return start_iter, start_training_loss, start_eval_loss

    def model_batch_step(self, batch, eval):
        """Run the model over one batch with teacher forcing.

        Args:
            batch: a batch produced by a ``Batcher``.
            eval: when true, also collect the top-1 token index per step.

        Returns:
            ``(loss, decoded)``. ``loss`` is the batch mean of each example's
            summed, padding-masked step losses divided by ``dec_lens_var``.
            ``decoded`` is the stacked top-1 indices when ``eval`` is true,
            otherwise None.
        """
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = self.model.encoder(enc_batch, enc_lens)
        s_t_1 = self.model.reduce_state(encoder_hidden)

        step_losses = []
        step_decoded_idx = []
        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing

            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = \
                self.model.decoder(y_t_1, s_t_1,
                                   encoder_outputs,
                                   encoder_feature,
                                   enc_padding_mask, c_t_1,
                                   extra_zeros,
                                   enc_batch_extend_vocab,
                                   coverage, di)

            if eval:
                _, top_idx = final_dist.topk(1)
                step_decoded_idx.append(top_idx)

            target = target_batch[:, di]
            # squeeze(1) (not squeeze()) so a batch of one keeps its batch axis.
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var
        loss = torch.mean(batch_avg_loss)

        final_decoded_sentences = None
        if eval:
            final_decoded_sentences = torch.stack(step_decoded_idx, 2).squeeze(1)
            print(final_decoded_sentences)

        return loss, final_decoded_sentences

    def train_one_batch(self, batch):
        """Do one optimizer step on ``batch`` and return the loss as a float.

        Gradients are clipped to ``config.max_grad_norm`` separately for the
        encoder, decoder and reduce_state; the encoder's pre-clip norm is
        stored in ``self.norm``.
        """
        self.optimizer.zero_grad()
        loss, _ = self.model_batch_step(batch, False)
        loss.backward()

        self.norm = clip_grad_norm_(self.model.encoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.decoder.parameters(), config.max_grad_norm)
        clip_grad_norm_(self.model.reduce_state.parameters(), config.max_grad_norm)

        self.optimizer.step()

        return loss.item()

    def run_eval(self):
        """Evaluate over the eval batcher until it yields None.

        Returns:
            The running-average eval loss from ``calc_running_avg_loss``.
        """
        self.model.eval()
        batch = self.batcher_eval.next_batch()
        step = 0
        start = time.time()
        running_avg_loss = 0
        with torch.no_grad():
            while batch is not None:
                loss, _ = self.model_batch_step(batch, False)
                loss = loss.item()
                running_avg_loss = calc_running_avg_loss(loss, running_avg_loss)
                batch = self.batcher_eval.next_batch()

                step += 1
                if step % config.print_interval == 0:
                    print('Eval steps %d, seconds for %d batch: %.2f , loss: %f' % (
                        step, config.print_interval, time.time() - start, running_avg_loss))
                    start = time.time()

        return running_avg_loss

    def trainIters(self, n_iters, model_file_path=None):
        """Train until ``n_iters`` total steps (counting resumed steps).

        Every 5000 steps runs evaluation, logs it and calls
        ``Checkpoint.check_loss``. Additionally saves a "coverage" checkpoint
        every 2000 steps in coverage mode and a "critical" one every 10000.
        """
        step, running_avg_loss_train, running_avg_loss_eval = self.setup_train(model_file_path)
        start = time.time()

        while step < n_iters:

            self.model.train()
            batch = self.batcher_train.next_batch()
            loss_train = self.train_one_batch(batch)
            running_avg_loss_train = calc_and_write_running_avg_loss(loss_train,
                                                                     "running_avg_loss_train",
                                                                     running_avg_loss_train,
                                                                     self.summary_writer_train,
                                                                     step)
            step += 1

            if step % 100 == 0:
                self.summary_writer_train.flush()

            if step % config.print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f, loss: %f, avg_loss: %f' % (step, config.print_interval,
                                                                                        time.time() - start,
                                                                                        loss_train,
                                                                                        running_avg_loss_train))
                start = time.time()

            if step % 5000 == 0:
                running_avg_loss_eval = self.run_eval()
                write_summary("running_avg_loss_eval",
                              running_avg_loss_eval,
                              self.summary_writer_eval,
                              step)
                self.summary_writer_eval.flush()
                self.chechpoint.check_loss(running_avg_loss_eval, running_avg_loss_train, step)
                start = time.time()
                # The eval batcher was built with single_pass=True; its threads
                # are restarted for the next evaluation round.
                self.batcher_eval.start_threads()

            if config.is_coverage and step % 2000 == 0:
                self.chechpoint.save_model("coverage", running_avg_loss_eval, running_avg_loss_train, step)
            if step % 10000 == 0:
                self.chechpoint.save_model("critical", running_avg_loss_eval, running_avg_loss_train, step)
Пример #9
0
def main(args):
    """Semi-supervised training/evaluation of ``MnistViVA`` on a pickled split.

    Loads the pickle at ``args.dataset`` (keys ``X_train_labeled``,
    ``y_train_labeled``, ``X_train_unlabeled``, ``y_train_unlabeled``,
    ``X_val``, ``y_val``, ``X_test``, ``y_test``). It then trains for epochs
    ``[args.start_epoch, args.epochs)``, interleaving labeled and unlabeled
    batches. Every epoch it validates and writes a checkpoint. Whenever
    validation accuracy improves, it also evaluates on the test set and
    writes ``train_result.txt``.

    Resuming is supported via ``args.resume``; early stopping via
    ``args.early_stop`` (> 0).
    """
    print('===> Configuration')
    print(args)

    os.makedirs(args.save, exist_ok=True)
    with open(os.path.join(args.save, "config.txt"), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    cudnn.benchmark = bool(args.cuda)
    device = torch.device("cuda" if args.cuda else "cpu")

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # MNIST dataset normalized between [0, 1]
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point args.dataset at trusted files.
    try:
        with open(args.dataset, 'rb') as f:
            dataset_dict = pickle.load(f)
    except Exception as e:
        print(str(e.__class__.__name__) + ": " + str(e))
        # Non-zero exit status so wrapper scripts can detect the failure.
        sys.exit(1)

    X_train_labeled = dataset_dict["X_train_labeled"]
    y_train_labeled = dataset_dict["y_train_labeled"]
    X_train_unlabeled = dataset_dict["X_train_unlabeled"]
    y_train_unlabeled = dataset_dict["y_train_unlabeled"]
    X_val = dataset_dict["X_val"]
    y_val = dataset_dict["y_val"]
    X_test = dataset_dict["X_test"]
    y_test = dataset_dict["y_test"]

    labeled_dataset = TensorDataset(
        torch.from_numpy(X_train_labeled).float(),
        torch.from_numpy(y_train_labeled).long())
    unlabeled_dataset = TensorDataset(
        torch.from_numpy(X_train_unlabeled).float(),
        torch.from_numpy(y_train_unlabeled).long())
    val_dataset = TensorDataset(
        torch.from_numpy(X_val).float(),
        torch.from_numpy(y_val).long())
    test_dataset = TensorDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).long())

    NUM_SAMPLES = len(labeled_dataset) + len(unlabeled_dataset)
    NUM_LABELED = len(labeled_dataset)

    labeled_dataloader = DataLoader(labeled_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.workers,
                                    drop_last=False)
    unlabeled_dataloader = DataLoader(unlabeled_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers,
                                      drop_last=False)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    # Label-loss weight scaled by total/labeled sample ratio.
    alpha = args.eta * NUM_SAMPLES / NUM_LABELED
    tau = CosineAnnealing(start=1.0, stop=0.5, t_max=args.tw, mode='down')

    model = MnistViVA(z_dim=args.z_dim,
                      hidden_dim=args.hidden,
                      zeta=args.zeta,
                      rho=args.rho,
                      device=device).to(device)
    optimizer = Adam(model.parameters())

    best_val_epoch = 0
    best_val_loss = sys.float_info.max
    best_val_acc = 0.0
    test_acc = 0.0
    early_stop_counter = 0

    if args.resume:
        if os.path.isfile(args.resume):
            print("===> Loading Checkpoint to Resume '{}'".format(args.resume))
            # map_location lets a GPU-saved checkpoint resume on CPU (and vice versa).
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch'] + 1
            best_val_epoch = checkpoint['best_epoch']
            best_val_loss = checkpoint['best_val_loss']
            # Checkpoints saved below use 'best_val_accuracy'; reading only
            # 'best_val_acc' made every resume fail with KeyError. Accept both.
            if 'best_val_accuracy' in checkpoint:
                best_val_acc = checkpoint['best_val_accuracy']
            else:
                best_val_acc = checkpoint['best_val_acc']
            # Checkpoints written before 'test_acc' was saved lack the key.
            test_acc = checkpoint.get('test_acc', test_acc)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("\t===> Loaded Checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "\t====> no checkpoint found at '{}'".format(args.resume))

    n_batches = len(labeled_dataloader) + len(unlabeled_dataloader)
    # Every n-th batch slot is a labeled one; the remaining slots (and any
    # labeled slot left after the labeled loader is exhausted) take unlabeled
    # batches.
    n_unlabeled_per_labeled = len(unlabeled_dataloader) // len(
        labeled_dataloader) + 1

    with tqdm(range(args.start_epoch, args.epochs), desc="Epochs") as nested:
        for epoch in nested:

            # Train
            model.train()
            train_recon_loss = AverageMeter('Train_Recon_Loss')
            train_latent_loss = AverageMeter('Train_Latent_Loss')
            train_label_loss = AverageMeter('Train_Label_Loss')
            train_tsne_loss = AverageMeter('Train_tSNE_Loss')
            train_total_loss = AverageMeter('Train_Total_Loss')
            train_accuracy = AverageMeter('Train_Accuracy')

            labeled_iter = iter(labeled_dataloader)
            unlabeled_iter = iter(unlabeled_dataloader)

            for batch_idx in range(n_batches):

                is_supervised = batch_idx % n_unlabeled_per_labeled == 0
                # get batch from respective dataloader
                if is_supervised:
                    try:
                        data, target = next(labeled_iter)
                        data = data.to(device)
                        target = target.to(device)
                        one_hot_target = one_hot(target, 10)
                    except StopIteration:
                        data, target = next(unlabeled_iter)
                        data = data.to(device)
                        target = target.to(device)
                        one_hot_target = None
                else:
                    data, target = next(unlabeled_iter)
                    data = data.to(device)
                    target = target.to(device)
                    one_hot_target = None

                model.zero_grad()

                recon_loss_sum, y_logits, t_coords, latent_loss_sum, tsne_loss = model(
                    data, one_hot_target, tau.step())
                recon_loss = recon_loss_sum / data.size(0)
                # Computed for every batch (for logging); only added to the
                # objective for batches that came from the labeled loader.
                label_loss = F.cross_entropy(y_logits,
                                             target,
                                             reduction='mean')
                latent_loss = latent_loss_sum / data.size(0)

                # Full loss
                total_loss = recon_loss + latent_loss + args.gamma * tsne_loss
                if is_supervised and one_hot_target is not None:
                    total_loss += alpha * label_loss

                assert not np.isnan(
                    total_loss.item()), 'Model diverged with loss = NaN'

                train_recon_loss.update(recon_loss.item())
                train_latent_loss.update(latent_loss.item())
                train_label_loss.update(label_loss.item())
                train_tsne_loss.update(tsne_loss.item())
                train_total_loss.update(total_loss.item())

                total_loss.backward()
                optimizer.step()

                pred = y_logits.argmax(
                    dim=1,
                    keepdim=True)  # get the index of the max log-probability
                train_correct = pred.eq(target.view_as(pred)).sum().item()
                train_accuracy.update(train_correct / data.size(0),
                                      data.size(0))

                if batch_idx % args.log_interval == 0:
                    tqdm.write(
                        'Train Epoch: {} [{}/{} ({:.0f}%)]\t Recon: {:.6f} Latent: {:.6f} t-SNE: {:.6f} Accuracy {:.4f} T {:.6f}'
                        .format(epoch, batch_idx, n_batches,
                                100. * batch_idx / n_batches,
                                train_recon_loss.avg, train_latent_loss.avg,
                                train_tsne_loss.avg, train_accuracy.avg,
                                tau.value))

            tqdm.write(
                '====> Epoch: {} Average train loss - Recon {:.3f} Latent {:.3f} t-SNE {:.6f} Label {:.6f} Accuracy {:.4f}'
                .format(epoch, train_recon_loss.avg, train_latent_loss.avg,
                        train_tsne_loss.avg, train_label_loss.avg,
                        train_accuracy.avg))

            # Validation
            model.eval()

            val_recon_loss = AverageMeter('Val_Recon_Loss')
            val_latent_loss = AverageMeter('Val_Latent_Loss')
            val_label_loss = AverageMeter('Val_Label_Loss')
            val_tsne_loss = AverageMeter('Val_tSNE_Loss')
            val_total_loss = AverageMeter('Val_Total_Loss')
            val_accuracy = AverageMeter('Val_Accuracy')

            with torch.no_grad():
                for i, (data, target) in enumerate(val_loader):
                    data = data.to(device)
                    target = target.to(device)

                    recon_loss_sum, y_logits, t_coords, latent_loss_sum, tsne_loss = model(
                        data, temperature=tau.value)

                    recon_loss = recon_loss_sum / data.size(0)
                    label_loss = F.cross_entropy(y_logits,
                                                 target,
                                                 reduction='mean')
                    latent_loss = latent_loss_sum / data.size(0)

                    # Full loss
                    total_loss = recon_loss + latent_loss + args.gamma * tsne_loss + alpha * label_loss

                    val_recon_loss.update(recon_loss.item())
                    val_latent_loss.update(latent_loss.item())
                    val_label_loss.update(label_loss.item())
                    val_tsne_loss.update(tsne_loss.item())
                    val_total_loss.update(total_loss.item())

                    pred = y_logits.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    val_correct = pred.eq(target.view_as(pred)).sum().item()
                    val_accuracy.update(val_correct / data.size(0),
                                        data.size(0))

            tqdm.write(
                '\t Validation loss - Recon {:.3f} Latent {:.3f} t-SNE {:.6f} Label: {:.6f} Accuracy {:.4f}'
                .format(val_recon_loss.avg, val_latent_loss.avg,
                        val_tsne_loss.avg, val_label_loss.avg,
                        val_accuracy.avg))

            # Model selection is by validation accuracy.
            is_best = val_accuracy.avg > best_val_acc
            if is_best:
                early_stop_counter = 0
                best_val_epoch = epoch
                best_val_loss = val_total_loss.avg
                best_val_acc = val_accuracy.avg

                test_accuracy = AverageMeter('Test_Accuracy')
                with torch.no_grad():
                    for i, (data, target) in enumerate(test_loader):
                        data = data.to(device)
                        target = target.to(device)

                        _, y_logits, _, _, _ = model(data,
                                                     temperature=tau.value)

                        pred = y_logits.argmax(
                            dim=1, keepdim=True
                        )  # get the index of the max log-probability
                        test_correct = pred.eq(
                            target.view_as(pred)).sum().item()
                        test_accuracy.update(test_correct / data.size(0),
                                             data.size(0))

                test_acc = test_accuracy.avg
                tqdm.write('\t Test Accuracy {:.4f}'.format(test_acc))
                with open(os.path.join(args.save, 'train_result.txt'),
                          'w') as f:
                    f.write('Best Validation Epoch: {}\n'.format(epoch))
                    f.write('Train Recon Loss: {}\n'.format(
                        train_recon_loss.avg))
                    f.write('Train Latent Loss: {}\n'.format(
                        train_latent_loss.avg))
                    f.write('Train tSNE Loss: {}\n'.format(
                        train_tsne_loss.avg))
                    f.write('Train Label Loss: {}\n'.format(
                        train_label_loss.avg))
                    f.write('Train Total Loss: {}\n'.format(
                        train_total_loss.avg))
                    f.write('Train Accuracy: {}\n'.format(train_accuracy.avg))
                    f.write('Val Recon Loss: {}\n'.format(val_recon_loss.avg))
                    f.write('Val Latent Loss: {}\n'.format(
                        val_latent_loss.avg))
                    f.write('Val tSNE Loss: {}\n'.format(val_tsne_loss.avg))
                    f.write('Val Label Loss: {}\n'.format(val_label_loss.avg))
                    f.write('Val Total Loss: {}\n'.format(val_total_loss.avg))
                    f.write('Val Accuracy: {}\n'.format(val_accuracy.avg))
                    f.write('Test Accuracy: {}\n'.format(test_acc))
            else:
                early_stop_counter += 1

            save_checkpoint(
                {
                    'epoch': epoch,
                    'best_epoch': best_val_epoch,
                    'best_val_loss': best_val_loss,
                    'best_val_accuracy': best_val_acc,
                    # Persisted so the final report is correct after a resume.
                    'test_acc': test_acc,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=os.path.join(args.save, 'checkpoint.pth'))

            if args.early_stop > 0 and early_stop_counter == args.early_stop:
                tqdm.write(
                    "Early Stop with no improvement: epoch {}".format(epoch))
                break

    print("Training is Completed!")
    print("Best Val Acc: {:.4f} Test Acc: {:.4f}".format(
        best_val_acc, test_acc))
Пример #10
0
class Brain:
    """Proximal Policy Optimisation (PPO) learner.

    Owns the actor-critic network (``Model``), its Adam optimizer and a
    ``LambdaLR`` scheduler whose multiplier decays linearly from 1 to 0 over
    ``n_iters`` scheduler steps.

    Args:
        state_shape: observation shape forwarded to ``Model``.
        n_actions: number of actions forwarded to ``Model``.
        device: torch device holding the network and all training tensors.
        n_workers: number of parallel environment workers in a rollout.
        epochs: number of optimisation passes over each rollout in ``train``.
        n_iters: horizon over which the learning-rate multiplier and the clip
            range (see ``schedule_clip_range``) decay linearly to 0.
        epsilon: initial PPO clip range (also used to clip the value update).
        lr: initial learning rate.
    """

    def __init__(self, state_shape, n_actions, device, n_workers, epochs,
                 n_iters, epsilon, lr):
        self.state_shape = state_shape
        self.n_actions = n_actions
        self.device = device
        self.n_workers = n_workers
        self.mini_batch_size = 32
        self.epochs = epochs
        self.n_iters = n_iters
        self.initial_epsilon = epsilon
        self.epsilon = self.initial_epsilon
        self.lr = lr

        self.current_policy = Model(self.state_shape,
                                    self.n_actions).to(self.device)

        self.optimizer = Adam(self.current_policy.parameters(),
                              lr=self.lr,
                              eps=1e-5)
        # Linear decay of the LR multiplier, floored at 0.
        self._schedule_fn = lambda step: max(1.0 - float(step / self.n_iters),
                                             0)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=self._schedule_fn)

    def get_actions_and_values(self, state, batch=False):
        """Sample actions from the current policy without tracking gradients.

        Args:
            state: numpy observation(s). A single observation unless ``batch``
                is True, in which case axis 0 is the batch axis. The array is
                cast to uint8 and permuted with ``[0, 3, 1, 2]`` (moves the
                last axis to position 1, presumably NHWC -> NCHW).
            batch: whether ``state`` already carries a batch axis.

        Returns:
            ``(actions, values, log_probs)`` as numpy arrays; ``values`` is
            squeezed.
        """
        if not batch:
            state = np.expand_dims(state, 0)
        state = from_numpy(state).byte().permute([0, 3, 1, 2]).to(self.device)
        with torch.no_grad():
            dist, value = self.current_policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.cpu().numpy(), value.detach().cpu().numpy().squeeze(
        ), log_prob.cpu().numpy()

    def choose_mini_batch(self, states, actions, returns, advs, values,
                          log_probs):
        """Yield one random mini-batch per worker.

        Indices are drawn with replacement from ``[0, states.shape[1])`` and
        applied to every array's ``[worker]`` row.
        """
        for worker in range(self.n_workers):
            idxes = np.random.randint(0, states.shape[1], self.mini_batch_size)
            yield states[worker][idxes], actions[worker][idxes], returns[worker][idxes], advs[worker][idxes], \
                  values[worker][idxes], log_probs[worker][idxes]

    def train(self, states, actions, rewards, dones, values, log_probs,
              next_values):
        """Run ``epochs`` PPO passes over one rollout.

        Every pass takes one optimizer step per worker mini-batch, using the
        clipped surrogate actor loss, a clipped value loss and an entropy
        bonus (coefficient 0.01).

        Args:
            states, actions, rewards, dones, values, log_probs: rollout data
                indexed ``[worker][step]``; ``values`` must be a numpy array
                (``.copy()`` is called on it).
            next_values: bootstrap value per worker for the step after the
                rollout.

        Returns:
            ``(total_loss, entropy, explained_variance)``. The losses come
            from the last mini-batch; the explained variance covers the whole
            rollout.

        Raises:
            ValueError: if no optimisation step was performed (e.g.
                ``epochs`` is 0). Previously this surfaced as an
                ``UnboundLocalError``.
        """
        returns = self.get_gae(rewards, values.copy(), next_values, dones)
        values = np.vstack(values)
        advs = returns - values
        # Normalise advantages per worker (row-wise).
        advs = (advs - advs.mean(1, keepdims=True)) / (
            advs.std(1, keepdims=True) + 1e-8)

        total_loss = entropy = None
        for epoch in range(self.epochs):
            for state, action, q_value, adv, old_value, old_log_prob in self.choose_mini_batch(
                    states, actions, returns, advs, values, log_probs):
                state = torch.ByteTensor(state).permute([0, 3, 1,
                                                         2]).to(self.device)
                action = torch.Tensor(action).to(self.device)
                adv = torch.Tensor(adv).to(self.device)
                q_value = torch.Tensor(q_value).to(self.device)
                old_value = torch.Tensor(old_value).to(self.device)
                old_log_prob = torch.Tensor(old_log_prob).to(self.device)

                dist, value = self.current_policy(state)
                entropy = dist.entropy().mean()
                new_log_prob = self.calculate_log_probs(
                    self.current_policy, state, action)
                ratio = (new_log_prob - old_log_prob).exp()
                actor_loss = self.compute_ac_loss(ratio, adv)

                # Clipped value loss: keep the new value within epsilon of the
                # old one, then take the pessimistic (max) squared error.
                clipped_value = old_value + torch.clamp(
                    value.squeeze() - old_value, -self.epsilon, self.epsilon)
                clipped_v_loss = (clipped_value - q_value).pow(2)
                unclipped_v_loss = (value.squeeze() - q_value).pow(2)
                critic_loss = 0.5 * torch.max(clipped_v_loss,
                                              unclipped_v_loss).mean()

                total_loss = critic_loss + actor_loss - 0.01 * entropy
                self.optimize(total_loss)

        if total_loss is None:
            raise ValueError(
                "train() performed no optimisation step "
                "(check that epochs and n_workers are positive)")

        return total_loss.item(), entropy.item(), \
               explained_variance(values.reshape((len(returns[0]) * self.n_workers,)),
                                  returns.reshape((len(returns[0]) * self.n_workers,)))

    def schedule_lr(self):
        """Advance the linear learning-rate decay by one step."""
        self.scheduler.step()

    def schedule_clip_range(self, iter):
        """Set the clip range to ``initial_epsilon * max(1 - iter / n_iters, 0)``."""
        self.epsilon = max(1.0 - float(iter / self.n_iters),
                           0) * self.initial_epsilon

    def optimize(self, loss):
        """Take one optimizer step on ``loss`` with gradient norm clipped to 0.5."""
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.current_policy.parameters(), 0.5)
        self.optimizer.step()

    def get_gae(self,
                rewards,
                values,
                next_values,
                dones,
                gamma=0.99,
                lam=0.95):
        """Compute GAE(lambda) returns (advantage + value) for every worker.

        Args:
            rewards, values, dones: per-worker sequences indexed
                ``[worker][step]``; all workers share the length of
                ``rewards[0]``.
            next_values: bootstrap value per worker after the last step.
            gamma: discount factor.
            lam: GAE lambda.

        Returns:
            float64 array of shape ``(n_workers, len(rewards[0]))``.
        """
        horizon = len(rewards[0])
        # Filled back-to-front in place (avoids O(T^2) list.insert(0, ...)).
        returns = np.zeros((self.n_workers, horizon))
        extended_values = np.zeros((self.n_workers, horizon + 1))
        for worker in range(self.n_workers):
            extended_values[worker] = np.append(values[worker],
                                                next_values[worker])
            gae = 0
            for step in reversed(range(horizon)):
                # A terminal step cuts both the bootstrap and the GAE trace.
                not_done = 1 - dones[worker][step]
                delta = rewards[worker][step] \
                        + gamma * extended_values[worker][step + 1] * not_done \
                        - extended_values[worker][step]
                gae = delta + gamma * lam * not_done * gae
                returns[worker][step] = gae + extended_values[worker][step]

        return returns

    @staticmethod
    def calculate_log_probs(model, states, actions):
        """Return ``model``'s log-probabilities of ``actions`` given ``states``."""
        policy_distribution, _ = model(states)
        return policy_distribution.log_prob(actions)

    def compute_ac_loss(self, ratio, adv):
        """Clipped PPO surrogate loss: ``-mean(min(r*A, clip(r, 1-eps, 1+eps)*A))``."""
        new_r = ratio * adv
        clamped_r = torch.clamp(ratio, 1 - self.epsilon,
                                1 + self.epsilon) * adv
        loss = torch.min(new_r, clamped_r)
        loss = -loss.mean()
        return loss

    def save_params(self, iteration, running_reward):
        """Save network, optimizer, scheduler and progress to ``params.pth``
        (relative to the current working directory)."""
        torch.save(
            {
                "current_policy_state_dict": self.current_policy.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "iteration": iteration,
                "running_reward": running_reward,
                "clip_range": self.epsilon
            }, "params.pth")

    def load_params(self):
        """Restore state written by ``save_params``.

        Returns:
            ``(running_reward, iteration)`` (note the order).
        """
        checkpoint = torch.load("params.pth", map_location=self.device)
        self.current_policy.load_state_dict(
            checkpoint["current_policy_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        iteration = checkpoint["iteration"]
        running_reward = checkpoint["running_reward"]
        self.epsilon = checkpoint["clip_range"]

        return running_reward, iteration

    def set_to_eval_mode(self):
        """Put the policy network in evaluation mode."""
        self.current_policy.eval()
Пример #11
0
class LSTMNetworks(nn.Module, ObservableData):
    '''
    Long short term memory (LSTM) networks.

    Long Short-Term Memory (LSTM) networks are a special RNN structure
    that has proven stable and powerful for modeling long-range
    dependencies.

    The key point of the structural expansion is the memory cell, which
    essentially acts as an accumulator of state information. Every time
    observed data points are given as new input to the LSTM's input gate,
    the information is accumulated in the cell if the input gate is
    activated. The past state of the cell can be forgotten in this process
    if the LSTM's forget gate is on. Whether the latest cell output is
    propagated to the final state is further controlled by the output gate.

    In this implementation the cell state additionally feeds the gates
    through bias-free linear layers (`input_gate_fc`, `forget_gate_fc`,
    `output_gate_fc`).

    References:
        - Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.
        - Malhotra, P., Ramakrishnan, A., Anand, G., Vig, L., Agarwal, P., & Shroff, G. (2016). LSTM-based encoder-decoder for multi-sensor anomaly detection. arXiv preprint arXiv:1607.00148.
        - Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural network regularization. arXiv preprint arXiv:1409.2329.

    '''

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss,
        initializer_f=None,
        optimizer_f=None,
        learning_rate=1e-05,
        seq_len=None,
        hidden_n=200,
        output_n=1,
        dropout_rate=0.5,
        input_adjusted_flag=True,
        observed_activation=torch.nn.Tanh(),
        input_gate_activation=torch.nn.Sigmoid(),
        forget_gate_activation=torch.nn.Sigmoid(),
        output_gate_activation=torch.nn.Sigmoid(),
        hidden_activation=torch.nn.Tanh(),
        output_activation=torch.nn.Tanh(),
        output_layer_flag=True,
        output_no_bias_flag=False,
        output_nn=None,
        ctx="cpu",
        regularizatable_data_list=None,
        not_init_flag=False,
    ):
        '''
        Init.

        Layers and the optimizer are not created here; they are built lazily
        by `initialize_params` once the input dimension is known.

        Args:
            computable_loss:            is-a `ComputableLoss` or `torch.nn.modules.loss._Loss`.
            initializer_f:              Callable applied to the freshly created weights of
                                        `observed_fc` and `hidden_fc`; its return value is
                                        assigned back as the weight. If `None`,
                                        `torch.nn.init.xavier_normal_` (gain 1.0) is used.
            optimizer_f:                Callable receiving `self.parameters()` and returning
                                        an optimizer. If `None`, `Adam` with `learning_rate`.
            learning_rate:              `float` of learning rate for the default optimizer.
            seq_len:                    `int` of the number of recurrent steps in feedforward.
                                        If `None`, the length of the observed sequence is used.
                                        Steps past the end of a multi-step observed sequence are
                                        skipped. If the observed sequence has a single step, every
                                        step after the first is fed the previous hidden activity
                                        (so the input dim must equal `hidden_n`); this lets the
                                        class act as a Seq2seq decoder.
            hidden_n:                   `int` of the number of units in hidden layer.
            output_n:                   `int` of the number of units in output layer.
            dropout_rate:               `float` of dropout rate applied to hidden activities
                                        (disabled when `0.0` or less).
            input_adjusted_flag:        `bool`. If `True`, each observed step is divided by
                                        its number of elements (batch size * input dim).
            observed_activation:        Callable activating the observed (candidate) input.
            input_gate_activation:      Callable activation of the input gate.
            forget_gate_activation:     Callable activation of the forget gate.
            output_gate_activation:     Callable activation of the output gate.
            hidden_activation:          Callable applied to the cell state before the output gate.
            output_activation:          Callable activation of the output layer, or the string
                                        `"identity"` (no activation) or `"identity_adjusted"`
                                        (divide by the number of elements).
            output_layer_flag:          `bool` that means this class has output layer or not.
            output_no_bias_flag:        `bool`; if `True` the output layer has no bias.
            output_nn:                  Module used as output layers. It must expose `units_list`
                                        and, during `learn`, an `optimizer`. If not `None`,
                                        `output_layer_flag` and `output_no_bias_flag` are ignored.
            ctx:                        torch device the model lives on (e.g. `"cpu"`, `"cuda"`).
            regularizatable_data_list:  `list` of `RegularizatableData`. `None` means empty.
            not_init_flag:              `bool`; if `True`, `initialize_params` creates no layers
                                        and no optimizer.

        Raises:
            TypeError: if `computable_loss` or an entry of `regularizatable_data_list`
                       has the wrong type.
        '''
        if not isinstance(computable_loss,
                          (ComputableLoss, nn.modules.loss._Loss)):
            raise TypeError(
                "The type of `computable_loss` must be `ComputableLoss` or `gluon.loss.Loss`."
            )

        super(LSTMNetworks, self).__init__()
        self.initializer_f = initializer_f
        self.optimizer_f = optimizer_f
        self.__not_init_flag = not_init_flag

        if dropout_rate > 0.0:
            self.dropout_forward = nn.Dropout(p=dropout_rate)
        else:
            self.dropout_forward = None

        self.__observed_activation = observed_activation
        self.__input_gate_activation = input_gate_activation
        self.__forget_gate_activation = forget_gate_activation
        self.__output_gate_activation = output_gate_activation
        self.__hidden_activation = hidden_activation
        self.__output_activation = output_activation

        self.__computable_loss = computable_loss
        self.__learning_rate = learning_rate
        self.__hidden_n = hidden_n
        self.__output_n = output_n
        self.__dropout_rate = dropout_rate
        self.__input_adjusted_flag = input_adjusted_flag

        # `None` default avoids sharing one mutable list between instances.
        if regularizatable_data_list is None:
            regularizatable_data_list = []
        for v in regularizatable_data_list:
            if not isinstance(v, RegularizatableData):
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `Regularizatable`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__ctx = ctx
        self.__logger = getLogger("accelbrainbase")

        # Set by the first call to `initialize_params`.
        self.__input_dim = None
        self.__input_seq_len = None

        self.__output_layer_flag = output_layer_flag
        self.__output_no_bias_flag = output_no_bias_flag
        self.__output_nn = output_nn
        self.seq_len = seq_len

        self.epoch = 0
        self.__loss_list = []

    def initialize_params(self, input_dim, input_seq_len):
        '''
        Create layers and the optimizer on the first call; later calls are no-ops.

        Args:
            input_dim:      The number of units in input layer.
            input_seq_len:  The length of input sequences. It sizes `output_fc`, which maps
                            the flattened hidden activities of all steps at once.
        '''
        if self.__input_dim is not None:
            return
        self.__input_dim = input_dim
        self.__input_seq_len = input_seq_len

        if self.__not_init_flag is False:
            if self.init_deferred_flag is False:
                # One projection produces all four gate pre-activations.
                self.observed_fc = nn.Linear(
                    input_dim,
                    self.__hidden_n * 4,
                    bias=False,
                )
                self.observed_fc.weight = self._initialize_weight(
                    self.observed_fc.weight)

                # The only biased term in the gate pre-activations.
                self.hidden_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n * 4,
                )
                # Initialise hidden_fc's own weight, not observed_fc's.
                self.hidden_fc.weight = self._initialize_weight(
                    self.hidden_fc.weight)

                # Cell-state -> gate connections (no bias).
                self.input_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )
                self.forget_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )
                self.output_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )

            self.output_fc = None
            self.output_nn = None
            if self.__output_layer_flag is True and self.__output_nn is None:
                # Different from mxnet version: one dense layer over all steps.
                self.output_fc = nn.Linear(
                    self.__hidden_n * self.__input_seq_len,
                    self.__output_n * self.__input_seq_len,
                    bias=not self.__output_no_bias_flag)
                self.__output_dim = self.__output_n
            elif self.__output_nn is not None:
                self.output_nn = self.__output_nn
                self.__output_dim = self.output_nn.units_list[-1]
            else:
                self.__output_dim = self.__hidden_n

        self.to(self.__ctx)
        if self.init_deferred_flag is False and self.__not_init_flag is False:
            if self.optimizer_f is None:
                self.optimizer = Adam(
                    self.parameters(),
                    lr=self.__learning_rate,
                )
            else:
                self.optimizer = self.optimizer_f(self.parameters())

    def _initialize_weight(self, weight):
        '''Initialise `weight` with `initializer_f`, or Xavier-normal (gain 1.0) when it is `None`.'''
        if self.initializer_f is None:
            return torch.nn.init.xavier_normal_(weight, gain=1.0)
        return self.initializer_f(weight)

    def learn(self, iteratable_data):
        '''
        Learn the observed data points
        for vector representation of the input time-series.

        Train and test losses are recorded (and logged) once every
        `iter_n / epochs` iterations; a `KeyboardInterrupt` stops training early.

        Args:
            iteratable_data:     is-a `IteratableData`.

        Raises:
            TypeError: if `iteratable_data` is not an `IteratableData`.
        '''
        if not isinstance(iteratable_data, IteratableData):
            raise TypeError(
                "The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        epoch = self.epoch
        iter_n = 0
        try:
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.__batch_size = batch_observed_arr.shape[0]
                self.__seq_len = batch_observed_arr.shape[1]
                self.initialize_params(
                    input_dim=batch_observed_arr.reshape(
                        self.__batch_size * self.__seq_len, -1).shape[-1],
                    input_seq_len=self.__seq_len)
                # A forward pass through `output_nn` presumably makes it
                # build its optimizer; `output_nn.optimizer` is used below.
                if self.output_nn is not None and not hasattr(
                        self.output_nn, "optimizer"):
                    _ = self.inference(batch_observed_arr)

                self.optimizer.zero_grad()
                if self.output_nn is not None:
                    self.output_nn.optimizer.zero_grad()

                # rank-3
                pred_arr = self.inference(batch_observed_arr)
                loss = self.compute_loss(pred_arr, batch_target_arr)
                loss.backward()

                if self.output_nn is not None:
                    self.output_nn.optimizer.step()
                self.optimizer.step()
                self.regularize()

                # Clamp to 1: with fewer iterations than epochs the interval
                # would be 0 and the modulo would raise ZeroDivisionError.
                log_interval = max(
                    1, int(iteratable_data.iter_n / iteratable_data.epochs))
                if (iter_n + 1) % log_interval == 0:
                    with torch.inference_mode():
                        # rank-3
                        test_pred_arr = self.inference(test_batch_observed_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_batch_target_arr)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.__logger.debug("Epochs: " + str(epoch + 1) +
                                        " Train loss: " + str(_loss) +
                                        " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.__logger.debug("Interrupt.")

        self.epoch = epoch

        self.__logger.debug("end. ")

    def inference(self, observed_arr):
        '''
        Inference the feature points to reconstruct the time-series.

        Args:
            observed_arr:           tensor of observed data points (batch, seq, ...).

        Returns:
            tensor of inferenced feature points (the result of `forward`).
        '''
        return self(observed_arr)

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       predicted tensor.
            labeled_arr:    target tensor.

        Returns:
            loss computed by `computable_loss`.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Pass the current parameters through every `RegularizatableData` in turn
        and load the result back into the model (non-strictly).
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            # One non-strict load is equivalent to loading key by key.
            self.load_state_dict(dict(params_dict), strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` copy of `self.state_dict()`.
        '''
        return dict(self.state_dict())

    def extract_feature_points(self):
        '''
        Return the hidden activities stored by the latest `forward` call
        (after dropout, before the output layer).

        Returns:
            tensor of feature points.
        '''
        return self.feature_points_arr

    def forward(self, x):
        '''
        Forward propagation.

        Args:
            x:      tensor whose first two axes are (batch, seq); trailing axes are flattened.

        Returns:
            Output of `output_nn` if set, else of the output layer when
            `output_layer_flag` is `True`, else the hidden activities.
        '''
        self.__batch_size = x.shape[0]
        self.__seq_len = x.shape[1]
        x = x.reshape(self.__batch_size, self.__seq_len, -1)
        self.initialize_params(input_dim=x.shape[2],
                               input_seq_len=self.__seq_len)

        hidden_activity_arr = self.hidden_forward_propagate(x)

        if self.__dropout_rate > 0:
            hidden_activity_arr = self.dropout_forward(hidden_activity_arr)
        self.feature_points_arr = hidden_activity_arr

        if self.output_nn is not None:
            return self.output_nn(hidden_activity_arr)
        if self.__output_layer_flag is True:
            # rank-3
            return self.output_forward_propagate(hidden_activity_arr)
        return hidden_activity_arr

    def hidden_forward_propagate(self, observed_arr):
        '''
        Unroll the LSTM over the sequence, starting from zero hidden and cell states.

        Args:
            observed_arr:           rank-3 tensor (batch, seq, dim) of observed data points.

        Returns:
            rank-3 tensor (batch, steps, hidden_n) of hidden activities, or
            `None` if every step was skipped.
        '''
        pred_arr = None

        hidden_activity_arr = torch.zeros((self.__batch_size, self.__hidden_n),
                                          dtype=torch.float32).to(self.__ctx)
        cec_activity_arr = torch.zeros((self.__batch_size, self.__hidden_n),
                                       dtype=torch.float32).to(self.__ctx)

        if self.seq_len is not None:
            cycle_n = self.seq_len
        else:
            cycle_n = self.__seq_len

        for cycle in range(cycle_n):
            if cycle == 0 or observed_arr.shape[1] > 1:
                x_arr = observed_arr[:, cycle:cycle + 1]
            else:
                # Single-step input: feed back the previous hidden activity.
                x_arr = torch.unsqueeze(pred_arr[:, -1], dim=1)

            # An empty slice (cycle beyond the observed length) is skipped.
            if x_arr.shape[1] == 0:
                continue

            hidden_activity_arr, cec_activity_arr = self.__lstm_forward(
                x_arr, hidden_activity_arr, cec_activity_arr)
            add_arr = torch.unsqueeze(hidden_activity_arr, dim=1)
            if pred_arr is None:
                pred_arr = add_arr
            else:
                pred_arr = torch.cat((pred_arr, add_arr), dim=1)

        return pred_arr

    def __lstm_forward(self, observed_arr, hidden_activity_arr,
                       cec_activity_arr):
        '''
        One LSTM step.

        Args:
            observed_arr:           rank-2 (or rank-3 with a single step) tensor of observed data points.
            hidden_activity_arr:    rank-2 tensor of activities in hidden layer.
            cec_activity_arr:       rank-2 tensor of activities in the constant error carousel (cell state).

        Returns:
            Tuple data.
            - rank-2 tensor of activities in hidden layer,
            - rank-2 tensor of activities in the constant error carousel.
        '''
        if len(observed_arr.shape) == 3:
            observed_arr = observed_arr[:, 0]

        if self.__input_adjusted_flag is True:
            # Divide by the number of elements in this step (batch * dim).
            observed_arr = torch.div(observed_arr,
                                     torch.sum(torch.ones_like(observed_arr)))

        # using bias (from hidden_fc only)
        lstm_matrix = self.observed_fc(observed_arr) + self.hidden_fc(
            hidden_activity_arr)

        (given_activity_arr, input_gate_activity_arr, forget_gate_activity_arr,
         output_gate_activity_arr) = lstm_matrix.split(self.__hidden_n, dim=1)

        # Cell state -> input/forget gates (no bias).
        input_gate_activity_arr = input_gate_activity_arr + self.input_gate_fc(
            cec_activity_arr)
        forget_gate_activity_arr = forget_gate_activity_arr + self.forget_gate_fc(
            cec_activity_arr)
        given_activity_arr = self.__observed_activation(given_activity_arr)
        input_gate_activity_arr = self.__input_gate_activation(
            input_gate_activity_arr)
        forget_gate_activity_arr = self.__forget_gate_activation(
            forget_gate_activity_arr)

        # rank-2
        _cec_activity_arr = torch.mul(
            given_activity_arr, input_gate_activity_arr) + torch.mul(
                forget_gate_activity_arr, cec_activity_arr)

        # New cell state -> output gate (no bias).
        output_gate_activity_arr = self.__output_gate_activation(
            output_gate_activity_arr + self.output_gate_fc(_cec_activity_arr))

        # rank-2
        _hidden_activity_arr = torch.mul(
            output_gate_activity_arr,
            self.__hidden_activation(_cec_activity_arr))

        return (_hidden_activity_arr, _cec_activity_arr)

    def output_forward_propagate(self, pred_arr):
        '''
        Forward propagation in output layer.

        Args:
            pred_arr:            rank-3 tensor (batch, seq, hidden_n) of hidden activities.

        Returns:
            rank-3 tensor (batch, seq, -1) of propagated data points.
        '''
        if self.__output_layer_flag is False:
            return pred_arr

        batch_size = pred_arr.shape[0]
        seq_len = pred_arr.shape[1]
        # Different from mxnet version: all steps are flattened into one dense layer.
        pred_arr = self.output_fc(torch.reshape(pred_arr, (batch_size, -1)))
        if self.__output_activation == "identity_adjusted":
            pred_arr = torch.div(pred_arr,
                                 torch.sum(torch.ones_like(pred_arr)))
        elif self.__output_activation != "identity":
            pred_arr = self.__output_activation(pred_arr)
        return torch.reshape(pred_arr, (batch_size, seq_len, -1))

    def save_parameters(self, filename):
        '''
        Save parameters to a file.

        Args:
            filename:       File name.
        '''
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
                'input_dim': self.__input_dim,
                'input_seq_len': self.__input_seq_len,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file written by `save_parameters`.

        Args:
            filename:       File name.
            ctx:            If not `None`, the device to move the model to after loading.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict() function. Default: `True`.
        '''
        # NOTE(review): the checkpoint stores a numpy array under 'loss'; on
        # PyTorch versions where torch.load defaults to weights_only=True this
        # load may be rejected — confirm the torch version in use.
        checkpoint = torch.load(filename)
        self.initialize_params(
            input_dim=checkpoint["input_dim"],
            input_seq_len=checkpoint["input_seq_len"],
        )
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter for read-only properties; always raises `TypeError`. '''
        raise TypeError("This property must be read-only.")

    __loss_list = []

    def get_loss_arr(self):
        ''' getter for losses: array of (train loss, test loss) pairs. '''
        return np.array(self.__loss_list)

    loss_arr = property(get_loss_arr, set_readonly)

    def get_output_dim(self):
        ''' getter for the output dimension (set by `initialize_params`). '''
        return self.__output_dim

    output_dim = property(get_output_dim, set_readonly)

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not.'''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)