Example #1
def save_model(epoch: int,
               model: LSTMTagger,
               optimizer: Adam,
               av_train_losses: List[float],
               av_eval_losses: List[float],
               model_file_name: str,
               word_to_ix: Union[BertTokenToIx, defaultdict],
               ix_to_word: Union[BertIxToToken, defaultdict],
               word_vocab: Optional[Vocab],
               tag_vocab: Vocab,
               char_to_ix: DefaultDict[str, int],
               models_folder: str,
               embedding_dim: int,
               char_embedding_dim: int,
               hidden_dim: int,
               char_hidden_dim: int,
               accuracy: float64,
               av_eval_loss: float,
               micro_precision: float64,
               micro_recall: float64,
               micro_f1: float64,
               weighted_macro_precision: float64,
               weighted_macro_recall: float64,
               weighted_macro_f1: float64,
               use_bert_cased: bool,
               use_bert_uncased: bool,
               use_bert_large: bool
               ) -> None:
    try:
        os.remove(os.path.join("..", models_folder, model_file_name))
    except FileNotFoundError:
        pass
    torch.save({
            'checkpoint_epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'av_train_losses': av_train_losses,
            'av_eval_losses': av_eval_losses,
            'word_to_ix': word_to_ix,
            'ix_to_word': ix_to_word,
            'word_vocab': word_vocab,
            'tag_vocab': tag_vocab,
            'char_to_ix': char_to_ix,
            'embedding_dim': embedding_dim,
            'char_embedding_dim': char_embedding_dim,
            'hidden_dim': hidden_dim,
            'char_hidden_dim': char_hidden_dim,
            'accuracy': accuracy,
            'av_eval_loss': av_eval_loss,
            'micro_precision': micro_precision,
            'micro_recall': micro_recall,
            'micro_f1': micro_f1,
            'weighted_macro_precision': weighted_macro_precision,
            'weighted_macro_recall': weighted_macro_recall,
            'weighted_macro_f1': weighted_macro_f1,
            'use_bert_cased': use_bert_cased,
            'use_bert_uncased': use_bert_uncased,
            'use_bert_large': use_bert_large
    }, os.path.join("..", models_folder, model_file_name))
    print("Model with lowest average eval loss successfully saved as: "+os.path.join("..", models_folder, model_file_name))
def from_instances(cls, model: WaveGlow, optimizer: Adam, hparams: HParams,
                   iteration: int):
    result = cls(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
        learning_rate=hparams.learning_rate,
        iteration=iteration,
        hparams=asdict(hparams),
    )
    return result
def from_instances(cls, model: Tacotron2, optimizer: Adam,
                   hparams: HParams, iteration: int, symbols: SymbolIdDict,
                   accents: AccentsDict, speakers: SpeakersDict):
    result = cls(state_dict=model.state_dict(),
                 optimizer=optimizer.state_dict(),
                 learning_rate=hparams.learning_rate,
                 iteration=iteration,
                 hparams=asdict(hparams),
                 symbols=symbols.raw(),
                 accents=accents.raw(),
                 speakers=speakers.raw())
    return result
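# Sketch only (not from the original source): the from_instances classmethods
# above imply a dataclass-style checkpoint container roughly like this; the
# class name and field types are assumptions.
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class CheckpointSketch:
    state_dict: Dict[str, Any]
    optimizer: Dict[str, Any]
    learning_rate: float
    iteration: int
    hparams: Dict[str, Any]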
Example #4
def main(args):
    """Train/ Cross validate for data source = YogiDB."""
    # Create data loader
    """Generic(data.Dataset)(image_set, annotations,
                     is_train=True, inp_res=256, out_res=64, sigma=1,
                     scale_factor=0, rot_factor=0, label_type='Gaussian',
                     rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)."""
    annotations_source = 'basic-thresholder'

    # Get the data from yogi
    db_obj = YogiDB(config.db_url)
    imageset = db_obj.get_filtered(ImageSet,
                                   name=args.image_set_name)
    annotations = db_obj.get_annotations(image_set_name=args.image_set_name,
                                         annotation_source=annotations_source)
    pts = torch.Tensor(annotations[0]['joint_self'])
    num_classes = pts.size(0)
    crop_size = 512
    if args.crop:
        crop_size = args.crop
        crop = True
    else:
        crop = False

    # Use a default RGB mean and std dev of 0
    RGB_MEAN = torch.as_tensor([0.0, 0.0, 0.0])
    RGB_STDDEV = torch.as_tensor([0.0, 0.0, 0.0])

    dataset = Generic(image_set=imageset,
                      inp_res=args.inp_res,
                      out_res=args.out_res,
                      annotations=annotations,
                      mode=args.mode,
                      crop=crop, crop_size=crop_size,
                      rgb_mean=RGB_MEAN, rgb_stddev=RGB_STDDEV)

    # Note: `dataset` must be copied here (requires `import copy` at module
    # level); otherwise train_dataset and val_dataset alias the same object
    # and the later `is_train = False` also switches the training data to
    # evaluation mode.
    train_dataset = copy.copy(dataset)
    train_dataset.is_train = True
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch, shuffle=True,
                              num_workers=args.workers, pin_memory=True)

    val_dataset = copy.copy(dataset)
    val_dataset.is_train = False
    val_loader = DataLoader(val_dataset,
                            batch_size=args.test_batch, shuffle=False,
                            num_workers=args.workers, pin_memory=True)

    # Select the hardware device to run on.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # Disable gradient calculations by default.
    torch.set_grad_enabled(False)

    # create checkpoint dir
    os.makedirs(args.checkpoint, exist_ok=True)

    if args.arch == 'hg1':
        model = hg1(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg2':
        model = hg2(pretrained=False, num_classes=num_classes)
    elif args.arch == 'hg8':
        model = hg8(pretrained=False, num_classes=num_classes)
    else:
        raise Exception('unrecognised model architecture: ' + args.arch)

    model = DataParallel(model).to(device)

    if args.optimizer == "Adam":
        optimizer = Adam(model.parameters(),
                         lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    else:
        optimizer = RMSprop(model.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    best_acc = 0

    # optionally resume from a checkpoint
    title = args.data_identifier + ' ' + args.arch
    if args.resume:
        assert os.path.isfile(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule, args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss, train_acc = do_training_epoch(train_loader, model, device, optimizer)

        # evaluate on validation set
        if args.debug == 1:
            valid_loss, valid_acc, predictions, validation_log = do_validation_epoch(
                val_loader, model, device, False, True,
                os.path.join(args.checkpoint, 'debug.csv'), epoch + 1)
        else:
            valid_loss, valid_acc, predictions, _ = do_validation_epoch(val_loader, model, device, False)

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc])

        # remember best acc and save checkpoint
        is_best = valid_acc > best_acc
        best_acc = max(valid_acc, best_acc)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, predictions, is_best, checkpoint=args.checkpoint, snapshot=args.snapshot)

    logger.close()
    logger.plot(['Train Acc', 'Val Acc'])
    savefig(os.path.join(args.checkpoint, 'log.eps'))
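# The adjust_learning_rate helper is not shown in this example; a plausible
# sketch with the same call signature, assuming step decay by `gamma` at the
# epochs listed in `schedule`:
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr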
Example #5
class ChatSpaceTrainer:
    def __init__(
        self,
        config,
        model: ChatSpaceModel,
        vocab: Vocab,
        device: torch.device,
        train_corpus_path,
        eval_corpus_path=None,
        encoding="utf-8",
    ):
        self.config = config
        self.device = device
        self.model = model
        self.optimizer = Adam(self.model.parameters(),
                              lr=config["learning_rate"])
        self.criterion = nn.NLLLoss()
        self.vocab = vocab
        self.encoding = encoding

        self.train_corpus = DynamicCorpus(train_corpus_path,
                                          repeat=True,
                                          encoding=self.encoding)
        self.train_dataset = ChatSpaceDataset(config,
                                              self.train_corpus,
                                              self.vocab,
                                              with_random_space=True)

        if eval_corpus_path is not None:
            self.eval_corpus = DynamicCorpus(eval_corpus_path,
                                             encoding=self.encoding)
            # Build the eval dataset from the corpus object, not the raw path.
            self.eval_dataset = ChatSpaceDataset(self.config,
                                                 self.eval_corpus,
                                                 self.vocab,
                                                 with_random_space=True)

        self.global_epochs = 0
        self.global_steps = 0

    def eval(self, batch_size=64):
        self.model.eval()

        with torch.no_grad():
            eval_output = self.run_epoch(self.eval_dataset,
                                         batch_size=batch_size,
                                         is_train=False)

        self.model.train()
        return eval_output

    def train(self, epochs=10, batch_size=64):
        for epoch_id in range(epochs):
            self.run_epoch(
                self.train_dataset,
                batch_size=batch_size,
                epoch_id=epoch_id,
                is_train=True,
                log_freq=self.config["logging_step"],
            )
            self.save_checkpoint(
                f"outputs/checkpoints/checkpoint_ep{epoch_id}.cpt")
            self.save_model(f"outputs/models/chatspace_ep{epoch_id}.pt")
            self.save_model(f"outputs/jit_models/chatspace_ep{epoch_id}.pt",
                            as_jit=False)

    def run_epoch(self,
                  dataset,
                  batch_size=64,
                  epoch_id=0,
                  is_train=True,
                  log_freq=100):
        step_outputs, step_metrics, step_inputs = [], [], []
        collect_fn = (ChatSpaceDataset.train_collect_fn
                      if is_train else ChatSpaceDataset.eval_collect_fn)
        data_loader = DataLoader(dataset, batch_size, collate_fn=collect_fn)
        for step_num, batch in enumerate(data_loader):
            batch = {
                key: value.to(self.device)
                for key, value in batch.items()
            }
            output = self.step(step_num, batch)

            if is_train:
                self.update(output["loss"])

            if not is_train or step_num % log_freq == 0:
                batch = {
                    key: value.cpu().numpy()
                    for key, value in batch.items()
                }
                output = {
                    key: value.detach().cpu().numpy()
                    for key, value in output.items()
                }

                metric = self.step_metric(output["output"], batch,
                                          output["loss"])

                if is_train:
                    print(
                        f"EPOCH:{epoch_id}",
                        f"STEP:{step_num}/{len(data_loader)}",
                        [(key + ":" + "%.3f" % metric[key]) for key in metric],
                    )
                else:
                    step_outputs.append(output)
                    step_metrics.append(metric)
                    step_inputs.append(batch)

        if not is_train:
            return self.epoch_metric(step_inputs, step_outputs, step_metrics)

        self.global_epochs += 1

    def epoch_metric(self, step_inputs, step_outputs, step_metrics):
        average_loss = np.mean([metric["loss"] for metric in step_metrics])

        epoch_inputs = [
            example for step_input in step_inputs
            for example in step_input["input"].tolist()
        ]
        epoch_outputs = [
            example for output in step_outputs
            for example in output["output"].argmax(axis=-1).tolist()
        ]
        epoch_labels = [
            example for step_input in step_inputs
            for example in step_input["label"].tolist()
        ]

        epoch_metric = calculated_metric(batch_input=epoch_inputs,
                                         batch_output=epoch_outputs,
                                         batch_label=epoch_labels)

        epoch_metric["loss"] = average_loss
        return epoch_metric

    def step_metric(self, output, batch, loss=None):
        metric = calculated_metric(
            batch_input=batch["input"].tolist(),
            batch_output=output.argmax(axis=-1).tolist(),
            batch_label=batch["label"].tolist(),
        )

        if loss is not None:
            metric["loss"] = loss
        return metric

    def step(self, step_num, batch, with_loss=True, is_train=True):
        output = self.model.forward(batch["input"], batch["length"])
        if is_train:
            self.global_steps += 1

        if not with_loss:
            return {"output": output}

        loss = self.criterion(output.transpose(1, 2), batch["label"])
        return {"loss": loss, "output": output}

    def update(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save_model(self, path, as_jit=False):
        self.optimizer.zero_grad()
        params = [{
            "param": param,
            "requires_grad": param.requires_grad
        } for param in self.model.parameters()]

        for param in params:
            # `requires_grad` (not `require_grad`) is the attribute PyTorch uses.
            param["param"].requires_grad = False

        with torch.no_grad():
            if not as_jit:
                torch.save(self.model.state_dict(), path)
            else:
                self.model.cpu().eval()

                sample_texts = ["오늘 너무 재밌지 않았어?", "너랑 하루종일 놀아서 기분이 좋았어!"]
                dataset = ChatSpaceDataset(self.config,
                                           sample_texts,
                                           self.vocab,
                                           with_random_space=False)
                data_loader = DataLoader(dataset,
                                         batch_size=2,
                                         collate_fn=dataset.eval_collect_fn)

                for batch in data_loader:
                    model_input = (batch["input"].detach(),
                                   batch["length"].detach())
                    traced_model = torch.jit.trace(self.model, model_input)
                    torch.jit.save(traced_model, path)
                    break

                self.model.to(self.device).train()

        print(f"Model Saved on {path}{' as_jit' if as_jit else ''}")

        for param in params:
            if param["requires_grad"]:
                param["param"].requires_grad = True

    def save_checkpoint(self, path):
        torch.save(
            {
                "epoch": self.global_epochs,
                "steps": self.global_steps,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            path,
        )

    def load_checkpoint(self, path):
        checkpoint = torch.load(path)
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.global_epochs = checkpoint["epoch"]
        self.global_steps = checkpoint["steps"]

    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path))
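# Usage sketch (not from the original source); the config keys and corpus
# paths below are assumptions:
#
#   config = {"learning_rate": 1e-3, "logging_step": 100}
#   trainer = ChatSpaceTrainer(config, model, vocab, torch.device("cpu"),
#                              train_corpus_path="data/train.txt",
#                              eval_corpus_path="data/eval.txt")
#   trainer.train(epochs=10, batch_size=64)
#   eval_metrics = trainer.eval(batch_size=64)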
Example #6
class Trainer:
    def __init__(self, args, data_loader):
        self.args = args
        self.data_loader = data_loader
        self.metric = PSNR()

        if args.is_perceptual_oriented:
            self.lr = args.p_lr
            self.content_loss_factor = args.p_content_loss_factor
            self.perceptual_loss_factor = args.p_perceptual_loss_factor
            self.adversarial_loss_factor = args.p_adversarial_loss_factor
            self.decay_iter = args.p_decay_iter
        else:
            self.lr = args.g_lr
            self.content_loss_factor = args.g_content_loss_factor
            self.perceptual_loss_factor = args.g_perceptual_loss_factor
            self.adversarial_loss_factor = args.g_adversarial_loss_factor
            self.decay_iter = args.g_decay_iter

        self.build_model(args)
        self.build_optimizer(args)
        if args.fp16: self.initialize_model_opt_fp16()
        if args.distributed: self.parallelize_model()
        self.history = {
            n: []
            for n in [
                'adversarial_loss', 'discriminator_loss', 'perceptual_loss',
                'content_loss', 'generator_loss', 'score'
            ]
        }
        if args.load: self.load_model(args)
        if args.resume: self.resume(args)
        self.build_scheduler(args)
        print(':D')

    def train(self, args):
        total_step = len(self.data_loader)
        adversarial_criterion = nn.BCEWithLogitsLoss().cuda()
        content_criterion = nn.L1Loss().cuda()
        perception_criterion = PerceptualLoss().cuda()
        self.best_score = -9999.
        self.generator.train()
        self.discriminator.train()

        print(f"{'epoch':>7s}"
              f"{'batch':>10s}"
              f"{'discr.':>10s}"
              f"{'gener.':>10s}"
              f"{'adver.':>10s}"
              f"{'percp.':>10s}"
              f"{'contn.':>10s}"
              f"{'PSNR':>10s}"
              f"")

        for epoch in range(args.epoch, args.num_epoch):
            sample_dir_epoch = Path(
                args.checkpoint_dir) / 'sample_dir' / str(epoch)
            sample_dir_epoch.mkdir(exist_ok=True, parents=True)

            for step, image in enumerate(self.data_loader):
                low_resolution = image['lr'].cuda()
                high_resolution = image['hr'].cuda()

                real_labels = torch.ones((high_resolution.size(0), 1)).cuda()
                fake_labels = torch.zeros((high_resolution.size(0), 1)).cuda()

                ##########################
                #   training generator   #
                ##########################
                self.optimizer_generator.zero_grad()
                fake_high_resolution = self.generator(low_resolution)

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution)
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, fake_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, real_labels)
                adversarial_loss = (adversarial_loss_fr +
                                    adversarial_loss_rf) / 2

                perceptual_loss = perception_criterion(high_resolution,
                                                       fake_high_resolution)
                content_loss = content_criterion(fake_high_resolution,
                                                 high_resolution)

                generator_loss = (
                    adversarial_loss * self.adversarial_loss_factor +
                    perceptual_loss * self.perceptual_loss_factor +
                    content_loss * self.content_loss_factor)

                if args.fp16:
                    with apex.amp.scale_loss(
                            generator_loss,
                            self.optimizer_generator) as scaled_loss:
                        scaled_loss.backward()
                else:
                    generator_loss.backward()

                self.optimizer_generator.step()

                ##########################
                # training discriminator #
                ##########################

                self.optimizer_discriminator.zero_grad()

                score_real = self.discriminator(high_resolution)
                score_fake = self.discriminator(fake_high_resolution.detach())
                discriminator_rf = score_real - score_fake.mean()
                discriminator_fr = score_fake - score_real.mean()

                adversarial_loss_rf = adversarial_criterion(
                    discriminator_rf, real_labels)
                adversarial_loss_fr = adversarial_criterion(
                    discriminator_fr, fake_labels)
                discriminator_loss = (adversarial_loss_fr +
                                      adversarial_loss_rf) / 2

                if args.fp16:
                    with apex.amp.scale_loss(
                            discriminator_loss,
                            self.optimizer_discriminator) as scaled_loss:
                        scaled_loss.backward()
                else:
                    discriminator_loss.backward()

                self.optimizer_discriminator.step()

                for _ in range(self.n_unit_scheduler_step):
                    self.lr_scheduler_generator.step()
                    self.lr_scheduler_discriminator.step()
                    self.unit_scheduler_step += 1

                score = self.metric(fake_high_resolution.detach(),
                                    high_resolution)
                print(
                    f'\r'
                    f"{epoch:>3d}:{args.num_epoch:<3d}"
                    f"{step:>5d}:{total_step:<4d}"
                    f"{discriminator_loss.item():>10.4f}"
                    f"{generator_loss.item():>10.4f}"
                    f"{adversarial_loss.item()*self.adversarial_loss_factor:>10.4f}"
                    f"{perceptual_loss.item()*self.perceptual_loss_factor:>10.4f}"
                    f"{content_loss.item()*self.content_loss_factor:>10.4f}"
                    f"{score.item():>10.4f}",
                    end='')

                # Save a sample of HR vs. generated SR images every 5000 steps.
                if step % 5000 == 0:
                    result = torch.cat(
                        (high_resolution, fake_high_resolution), 2)
                    save_image(result, sample_dir_epoch / f"SR_{step}.png")

            self.history['adversarial_loss'].append(
                adversarial_loss.item() * self.adversarial_loss_factor)
            self.history['discriminator_loss'].append(
                discriminator_loss.item())
            self.history['perceptual_loss'].append(perceptual_loss.item() *
                                                   self.perceptual_loss_factor)
            self.history['content_loss'].append(content_loss.item() *
                                                self.content_loss_factor)
            self.history['generator_loss'].append(generator_loss.item())
            self.history['score'].append(score.item())

            self.save(epoch, 'last.pth')
            if score > self.best_score:
                self.best_score = score
                self.save(epoch, 'best.pth')

    def build_model(self, args):
        self.generator = ESRGAN(3, 3, 64,
                                scale_factor=args.scale_factor).cuda()
        self.discriminator = Discriminator().cuda()

    def build_optimizer(self, args):
        self.optimizer_generator = Adam(self.generator.parameters(),
                                        lr=self.lr,
                                        betas=(args.b1, args.b2),
                                        weight_decay=args.weight_decay)
        self.optimizer_discriminator = Adam(self.discriminator.parameters(),
                                            lr=self.lr,
                                            betas=(args.b1, args.b2),
                                            weight_decay=args.weight_decay)

    def initialize_model_opt_fp16(self):
        self.generator, self.optimizer_generator = apex.amp.initialize(
            self.generator, self.optimizer_generator, opt_level='O2')
        self.discriminator, self.optimizer_discriminator = apex.amp.initialize(
            self.discriminator, self.optimizer_discriminator, opt_level='O2')

    def parallelize_model(self):
        self.generator = apex.parallel.DistributedDataParallel(
            self.generator, delay_allreduce=True)
        self.discriminator = apex.parallel.DistributedDataParallel(
            self.discriminator, delay_allreduce=True)

    def build_scheduler(self, args):
        if not hasattr(self, 'unit_scheduler_step'):
            self.unit_scheduler_step = -1
        self.n_unit_scheduler_step = (args.batch_size // 16) * args.nodes
        print(f'Batch size: {args.batch_size}. '
              f'Number of nodes: {args.nodes}. '
              f'Each step here equates to {self.n_unit_scheduler_step} '
              f'unit scheduler step in the paper.\n'
              f'Current unit scheduler step: {self.unit_scheduler_step}.')
        self.lr_scheduler_generator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_generator,
            milestones=self.decay_iter,
            gamma=.5,
            last_epoch=self.unit_scheduler_step if args.resume else -1)
        self.lr_scheduler_discriminator = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_discriminator,
            milestones=self.decay_iter,
            gamma=.5,
            last_epoch=self.unit_scheduler_step if args.resume else -1)

    def load_model(self, args):
        path_to_load = Path(args.load)
        if path_to_load.is_file():
            cpt = torch.load(path_to_load,
                             map_location=lambda storage, loc: storage.cuda())
            g_sdict = cpt['g_state_dict']
            d_sdict = cpt['d_state_dict']
            if g_sdict is not None:
                if not args.distributed:
                    g_sdict = {
                        k[7:] if k.startswith('module.') else k: v
                        for k, v in g_sdict.items()
                    }
                self.generator.load_state_dict(g_sdict)
                print(f'[*] Loading generator from {path_to_load}')
            if d_sdict is not None:
                if not args.distributed:
                    d_sdict = {
                        k[7:] if k.startswith('module.') else k: v
                        for k, v in d_sdict.items()
                    }
                self.discriminator.load_state_dict(d_sdict)
                print(f'[*] Loading discriminator from {path_to_load}')
            if args.fp16 and cpt['amp'] is not None:
                apex.amp.load_state_dict(cpt['amp'])
        else:
            print(f'[!] No checkpoint found at {path_to_load}')

    def resume(self, args):
        path_to_resume = Path(args.resume)
        if path_to_resume.is_file():
            cpt = torch.load(path_to_resume,
                             map_location=lambda storage, loc: storage.cuda())
            if cpt['epoch'] is not None: args.epoch = cpt['epoch'] + 1
            if cpt['unit_scheduler_step'] is not None:
                self.unit_scheduler_step = cpt['unit_scheduler_step'] + 1
            if cpt['history'] is not None: self.history = cpt['history']
            g_sdict, d_sdict = cpt['g_state_dict'], cpt['d_state_dict']
            optg_sdict = cpt['opt_g_state_dict']
            optd_sdict = cpt['opt_d_state_dict']
            if g_sdict is not None:
                if not args.distributed:
                    g_sdict = {
                        k[7:] if k.startswith('module.') else k: v
                        for k, v in g_sdict.items()
                    }
                self.generator.load_state_dict(g_sdict)
                print(f'[*] Loading generator from {path_to_resume}')
            if d_sdict is not None:
                if not args.distributed:
                    d_sdict = {
                        k[7:] if k.startswith('module.') else k: v
                        for k, v in d_sdict.items()
                    }
                self.discriminator.load_state_dict(d_sdict)
                print(f'[*] Loading discriminator from {path_to_resume}')
            if optg_sdict is not None:
                self.optimizer_generator.load_state_dict(optg_sdict)
                print(f'[*] Loading generator optimizer from {path_to_resume}')
            if optd_sdict is not None:
                self.optimizer_discriminator.load_state_dict(optd_sdict)
                print(f'[*] Loading discriminator optimizer '
                      f'from {path_to_resume}')
            if args.fp16 and cpt['amp'] is not None:
                apex.amp.load_state_dict(cpt['amp'])
        else:
            raise ValueError(
                f'[!] No checkpoint to resume from at {path_to_resume}')

    def save(self, epoch, filename):
        g_sdict = self.generator.state_dict()
        d_sdict = self.discriminator.state_dict()
        if not self.args.distributed:
            g_sdict = {f'module.{k}': v for k, v in g_sdict.items()}
            d_sdict = {f'module.{k}': v for k, v in d_sdict.items()}
        save_dict = {
            'epoch': epoch,
            'unit_scheduler_step': self.unit_scheduler_step,
            'history': self.history,
            'g_state_dict': g_sdict,
            'd_state_dict': d_sdict,
            'opt_g_state_dict': self.optimizer_generator.state_dict(),
            'opt_d_state_dict': self.optimizer_discriminator.state_dict(),
            'amp': apex.amp.state_dict() if self.args.fp16 else None,
            'args': self.args
        }
        torch.save(save_dict, Path(self.args.checkpoint_dir) / filename)
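# Usage sketch (assumed): `args` is an argparse.Namespace carrying the fields
# referenced above (p_lr/g_lr, *_loss_factor, *_decay_iter, fp16, distributed,
# load, resume, epoch, num_epoch, checkpoint_dir, batch_size, nodes,
# scale_factor, b1, b2, weight_decay), and `data_loader` yields dicts with
# 'lr' and 'hr' image batches.
#
#   trainer = Trainer(args, data_loader)
#   trainer.train(args)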
Example #7
                    # Batch logs
                    pbar.set_description(
                        'VALIDATION: Current epoch: {}; '
                        'Loss {loss.val:.4f} ({loss.avg:.4f}); '
                        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                            epoch, loss=run_val_loss, top1=run_val_acc))

                # Epoch logs
                logger.add_scalars('epoch/acc', {'val': run_val_acc.avg},
                                   epoch)

            # save best model and latest model with all required parameters
            if run_val_acc.avg > best_metric:
                best_metric = run_val_acc.avg
                save_best = True
            else:
                save_best = False

            save_checkpoint(
                OUT_DIR, 'model', {
                    'epoch':
                    epoch,
                    'state_dict':
                    model.module.state_dict()
                    if gpus and len(gpus) > 1 else model.state_dict(),
                    'optimizer':
                    optimizer.state_dict(),
                    'scheduler':
                    scheduler.state_dict(),
                }, save_best)
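# The save_checkpoint helper is not shown in the fragment above; a plausible
# sketch with the same call signature (out_dir, prefix, state, save_best):
import os
import shutil
import torch

def save_checkpoint(out_dir, prefix, state, save_best):
    latest_path = os.path.join(out_dir, f'{prefix}_latest.pth')
    torch.save(state, latest_path)
    if save_best:
        # Keep a separate copy of the best-performing weights.
        shutil.copyfile(latest_path, os.path.join(out_dir, f'{prefix}_best.pth'))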
Example #8
class NeuralNetworks(nn.Module, ObservableData):
    '''
    Neural Networks.

    References:
        - Kamyshanska, H., & Memisevic, R. (2014). The potential energy of an autoencoder. IEEE transactions on pattern analysis and machine intelligence, 37(6), 1261-1273.
    '''

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss,
        initializer_f=None,
        optimizer_f=None,
        learning_rate=1e-05,
        units_list=[100, 1],
        dropout_rate_list=[0.0, 0.5],
        activation_list=[
            torch.nn.functional.tanh, torch.nn.functional.sigmoid
        ],
        hidden_batch_norm_list=[100, None],
        ctx="cpu",
        regularizatable_data_list=[],
        scale=1.0,
        output_no_bias_flag=False,
        all_no_bias_flag=False,
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a `ComputableLoss` or `nn.modules.loss._Loss`.
            initializer_f:                  A function that contains `torch.nn.init`.
                                            This function receive `tensor` as input and output initialized `tensor`. 
                                            If `None`, it is drawing from the Xavier distribution.

            optimizer_f:                    A function that contains `torch.optim.optimizer.Optimizer` for parameters of model.
                                            This function receive `self.parameters()` as input and output `torch.optim.optimizer.Optimizer`.

            learning_rate:                  `float` of learning rate.
            units_list:                     `list` of `int` of the number of units in hidden/output layers.
            dropout_rate_list:              `list` of `float` of dropout rates.
            activation_list:                `list` of activation functions (e.g. `torch.nn.functional.tanh`) or of the strings `"identity"`, `"identity_adjusted"`, `"softmax"`, `"log_softmax"`.
            hidden_batch_norm_list:         `list` of `int` (number of features for `nn.BatchNorm1d`), batch-norm `torch.nn.Module` instances, or `None`.
            ctx:                            Device specifier (e.g. `"cpu"` or `"cuda:0"`) that the model is moved to.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            scale:                          `float` of scaling factor for initial parameters.
            output_no_bias_flag:            `bool` for using bias or not in output layer(last hidden layer).
            all_no_bias_flag:               `bool` for using bias or not in all layer.
            not_init_flag:                  `bool` of whether initialize parameters or not.
        '''
        super(NeuralNetworks, self).__init__()

        if isinstance(computable_loss, ComputableLoss) is False and isinstance(
                computable_loss, nn.modules.loss._Loss) is False:
            raise TypeError(
                "The type of `computable_loss` must be `ComputableLoss` or `nn.modules.loss._Loss`."
            )

        if len(units_list) != len(activation_list):
            raise ValueError(
                "The length of `units_list` and `activation_list` must be equivalent."
            )
        self.__units_list = units_list

        if len(dropout_rate_list) != len(units_list):
            raise ValueError(
                "The length of `dropout_rate_list` and `units_list` must be equivalent."
            )

        self.initializer_f = initializer_f
        self.optimizer_f = optimizer_f
        self.__units_list = units_list
        self.__all_no_bias_flag = all_no_bias_flag
        self.__output_no_bias_flag = output_no_bias_flag

        self.dropout_forward_list = [None] * len(dropout_rate_list)
        for i in range(len(dropout_rate_list)):
            self.dropout_forward_list[i] = nn.Dropout(p=dropout_rate_list[i])
        self.dropout_forward_list = nn.ModuleList(self.dropout_forward_list)

        self.hidden_batch_norm_list = [None] * len(hidden_batch_norm_list)
        for i in range(len(hidden_batch_norm_list)):
            if hidden_batch_norm_list[i] is not None:
                if isinstance(hidden_batch_norm_list[i], int) is True:
                    self.hidden_batch_norm_list[i] = nn.BatchNorm1d(
                        hidden_batch_norm_list[i])
                else:
                    self.hidden_batch_norm_list[i] = hidden_batch_norm_list[i]

        self.hidden_batch_norm_list = nn.ModuleList(
            self.hidden_batch_norm_list)

        self.__not_init_flag = not_init_flag
        self.activation_list = activation_list

        self.__computable_loss = computable_loss
        self.__learning_rate = learning_rate

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__ctx = ctx

        self.fc_list = []
        self.flatten = nn.Flatten()

        self.epoch = 0
        self.__loss_list = []
        logger = getLogger("accelbrainbase")
        self.__logger = logger
        self.__input_dim = None

    def initialize_params(self, input_dim):
        '''
        Initialize params.

        Args:
            input_dim:      The number of units in input layer.
        '''
        if self.__input_dim is not None:
            return
        self.__input_dim = input_dim

        if len(self.fc_list) > 0:
            return

        if self.__all_no_bias_flag is True:
            use_bias = False
        else:
            use_bias = True

        fc = nn.Linear(input_dim, self.__units_list[0], bias=use_bias)
        if self.initializer_f is None:
            fc.weight = torch.nn.init.xavier_normal_(fc.weight, gain=1.0)
        else:
            fc.weight = self.initializer_f(fc.weight)

        fc_list = [fc]

        for i in range(1, len(self.__units_list)):
            if self.__all_no_bias_flag is True:
                use_bias = False
            elif self.__output_no_bias_flag is True and i + 1 == len(
                    self.__units_list):
                use_bias = False
            else:
                use_bias = True

            fc = nn.Linear(self.__units_list[i - 1],
                           self.__units_list[i],
                           bias=use_bias)

            if self.initializer_f is None:
                fc.weight = torch.nn.init.xavier_normal_(fc.weight, gain=1.0)
            else:
                fc.weight = self.initializer_f(fc.weight)

            fc_list.append(fc)

        self.fc_list = nn.ModuleList(fc_list)
        self.to(self.__ctx)

        if self.init_deferred_flag is False:
            if self.__not_init_flag is False:
                if self.optimizer_f is None:
                    self.optimizer = Adam(
                        self.parameters(),
                        lr=self.__learning_rate,
                    )
                else:
                    self.optimizer = self.optimizer_f(self.parameters(), )

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `IteratableData`.
        '''
        if isinstance(iteratable_data, IteratableData) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate
        try:
            epoch = self.epoch
            iter_n = 0
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.initialize_params(
                    input_dim=self.flatten(batch_observed_arr).shape[-1])
                self.optimizer.zero_grad()
                # rank-3
                pred_arr = self.inference(batch_observed_arr)
                loss = self.compute_loss(pred_arr, batch_target_arr)
                loss.backward()
                self.optimizer.step()
                self.regularize()

                if (iter_n + 1) % int(
                        iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        # rank-3
                        test_pred_arr = self.inference(test_batch_observed_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_batch_target_arr)
                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.__logger.debug("Epochs: " + str(epoch + 1) +
                                        " Train loss: " + str(_loss) +
                                        " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.__logger.debug("Interrupt.")

        self.epoch = epoch
        self.__logger.debug("end. ")

    def inference(self, observed_arr):
        '''
        Inference samples drawn by `IteratableData.generate_inferenced_samples()`.

        Args:
            observed_arr:   rank-2 Array like or sparse matrix as the observed data points.
                            The shape is: (batch size, feature points)

        Returns:
            `tensor` of inferenced feature points.
        '''
        return self(observed_arr)

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor` of predicted values.
            labeled_arr:    `torch.Tensor` of labels.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def forward(self, x):
        '''
        Forward with torch.

        Args:
            x:      `tensor` of observed data points.
        
        Returns:
            `tensor` of inferenced feature points.
        '''
        x = self.flatten(x)
        self.initialize_params(input_dim=x.shape[-1])
        for i in range(len(self.activation_list)):
            x = self.fc_list[i](x)

            if self.activation_list[i] == "identity_adjusted":
                x = x / torch.sum(torch.ones_like(x))
            elif self.activation_list[i] == "softmax":
                x = F.softmax(x)
            elif self.activation_list[i] == "log_softmax":
                x = F.log_softmax(x)
            elif self.activation_list[i] != "identity":
                x = self.activation_list[i](x)

            if self.dropout_forward_list[i] is not None:
                x = self.dropout_forward_list[i](x)
            if self.hidden_batch_norm_list[i] is not None:
                x = self.hidden_batch_norm_list[i](x)

        return x

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
                'input_dim': self.__input_dim,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.initialize_params(input_dim=checkpoint["input_dim"])
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_loss_arr(self):
        ''' getter for losses. '''
        return np.array(self.__loss_list)

    loss_arr = property(get_loss_arr, set_readonly)

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not. '''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    def get_units_list(self):
        ''' getter for `list` of units in each layer. '''
        return self.__units_list

    units_list = property(get_units_list, set_readonly)
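# Usage sketch (not from the original source). Any `nn.modules.loss._Loss`
# satisfies the computable_loss type check; `iteratable_data` stands for an
# IteratableData instance supplied by the caller.
#
#   nn_model = NeuralNetworks(
#       computable_loss=nn.MSELoss(),
#       units_list=[100, 1],
#       dropout_rate_list=[0.0, 0.0],
#       activation_list=[torch.tanh, "identity"],
#       hidden_batch_norm_list=[100, None],
#       ctx="cpu",
#   )
#   nn_model.learn(iteratable_data)
#   predictions = nn_model.inference(observed_tensor)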
class ReactionModel(Model):
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_meta,
                 d_classifier,
                 d_output,
                 threshold,
                 stack='ReactionAttention',
                 expansion_layer='LinearExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''**kwargs: n_layers, n_head, dropout, use_bottleneck, d_bottleneck'''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }
        # **kwargs: n_head, n_depth, d_bottleneck=256, dropout=0.1, use_bottleneck=True
        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       d_meta=d_meta,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            self.model.cuda()
            self.classifier.cuda()
        else:
            print('CUDA not found or not enabled, using CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = Adam(self.parameters,
                                  lr=0.001,
                                  betas=(0.9, 0.999))
        else:
            self.optimizer = optimizer

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # ---------------------------- END INIT ---------------------------- #

    def train_epoch(self, train_dataloader, eval_dataloader, device, smoothing):
        ''' Epoch operation in training phase'''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE
        # Set model and classifier training mode
        self.model.train()
        self.classifier.train()

        total_loss = 0
        batch_counter = 0

        # update param per batch
        for batch in tqdm(train_dataloader,
                          mininterval=1,
                          desc='  - (Training)   ',
                          leave=False):  # training_data should be an iterable

            # get data from dataloader
            feature_1, feature_2, y = map(lambda x: x.to(device), batch)

            batch_size = len(feature_1)

            # forward
            self.optimizer.zero_grad()
            logits, attn = self.model(feature_1, feature_2)
            logits = logits.view(batch_size, -1)
            logits = self.classifier(logits)

            # Judge if it's a regression problem
            if self.d_output == 1:
                pred = logits.sigmoid()
                loss = mse_loss(pred, y)

            else:
                '''Do we need to apply softmax before calculating the loss?'''
                pred = logits
                loss = cross_entropy_loss(pred, y, smoothing=smoothing)

            # calculate gradients
            loss.backward()

            # update parameters
            self.optimizer.step()

            # get metrics for logging
            acc = accuracy(pred, y, threshold=self.threshold)
            precision, recall, precision_avg, recall_avg = precision_recall(
                pred, y, self.d_output, threshold=self.threshold)
            total_loss += loss.item()
            batch_counter += 1

            # training control
            state_dict = self.controller(batch_counter)

            if state_dict['step_to_print']:
                self.train_logger.info(
                    '[TRAINING]   - batch: %5d, loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                    % (batch_counter, loss, acc, precision[1], recall[1]))

            if state_dict['step_to_evaluate']:
                stop = self.val_epoch(eval_dataloader, device,
                                      state_dict['step'])
                state_dict['step_to_stop'] = stop
                if stop:
                    break

            if self.controller.current_step == self.controller.max_step:
                state_dict['step_to_stop'] = True
                break

        return state_dict

    def val_epoch(self, dataloader, device, step=0, plot=False):
        ''' Epoch operation in evaluation phase '''
        if device == 'cuda':
            assert self.CUDA_AVAILABLE

        # Set model and classifier training mode
        self.model.eval()
        self.classifier.eval()

        # use evaluator to calculate the average performance
        evaluator = Evaluator()

        pred_list = []
        real_list = []

        with torch.no_grad():

            for batch in tqdm(
                    dataloader,
                    mininterval=5,
                    desc='  - (Evaluation)   ',
                    leave=False):  # training_data should be an iterable

                # get data from dataloader
                feature_1, feature_2, y = map(lambda x: x.to(device), batch)
                batch_size = len(feature_1)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(batch_size, -1)
                logits = self.classifier(logits)

                if self.d_output == 1:
                    pred = logits.sigmoid()
                    loss = mse_loss(pred, y)

                else:
                    pred = logits
                    loss = cross_entropy_loss(pred, y, smoothing=False)

                acc = accuracy(pred, y, threshold=self.threshold)
                precision, recall, _, _ = precision_recall(
                    pred, y, self.d_output, threshold=self.threshold)

                # feed the metrics in the evaluator
                evaluator(loss.item(), acc.item(), precision[1].item(),
                          recall[1].item())
                '''append the results to the predict / real list for drawing ROC or PR curve.'''
                if plot:
                    pred_list += pred.tolist()
                    real_list += y.tolist()

            if plot:
                area, precisions, recalls, thresholds = pr(
                    pred_list, real_list)
                plot_pr_curve(recalls, precisions, auc=area)

            # get evaluation results from the evaluator
            loss_avg, acc_avg, pre_avg, rec_avg = evaluator.avg_results()

            self.train_logger.info(
                '[EVALUATION] - eval_loss: %3.4f, acc: %1.4f, pre: %1.4f, rec: %1.4f'
                % (loss_avg, acc_avg, pre_avg, rec_avg))

            state_dict = self.early_stopping(loss_avg)

            if state_dict['save']:
                checkpoint = {
                    'model_state_dict': self.model.state_dict(),
                    'classifier_state_dict': self.classifier.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'global_step': step
                }
                self.save_model(checkpoint,
                                self.save_path + '-loss-%.5f' % loss_avg)

            return state_dict['break']

    def train(self, epoch, train_dataloader, eval_dataloader, device,
              smoothing, save_mode):
        # set logger
        self.set_logger()

        assert save_mode in ['all', 'best']
        # train for n epoch
        for epoch_i in range(epoch):
            print('[ Epoch', epoch_i, ']')
            # set current epoch
            self.controller.set_epoch(epoch_i + 1)
            # train for one epoch
            state_dict = self.train_epoch(train_dataloader, eval_dataloader,
                                          device, smoothing)

            self.val_epoch(eval_dataloader, device, plot=True)

            if state_dict['step_to_stop']:
                break

        print(
            '[INFO]: Finished training after %d epoch(s) and %d batches, %d training steps in total.'
            % (state_dict['epoch'] - 1, state_dict['batch'], state_dict['step']))

    def get_predictions(self,
                        data_loader,
                        device,
                        max_batches=None,
                        activation=None):

        pred_list = []
        real_list = []

        self.model.eval()
        self.classifier.eval()

        batch_counter = 0

        with torch.no_grad():
            for batch in tqdm(data_loader,
                              desc='  - (Testing)   ',
                              leave=False):

                # get data from dataloader
                feature_1, feature_2, y = map(lambda x: x.to(device), batch)

                # get logits
                logits, attn = self.model(feature_1, feature_2)
                logits = logits.view(logits.shape[0], -1)
                logits = self.classifier(logits)

                # optionally apply an activation function to the raw logits
                if activation is not None:
                    pred = activation(logits)
                else:
                    pred = logits
                pred_list += pred.tolist()
                real_list += y.tolist()

                if max_batches is not None:
                    batch_counter += 1
                    if batch_counter >= max_batches:
                        break

        return pred_list, real_list
Example #10
def main(args):
    print('===> Configuration')
    print(args)

    os.makedirs(args.save, exist_ok=True)
    with open(os.path.join(args.save, "config.txt"), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    cudnn.benchmark = True if args.cuda else False
    device = torch.device("cuda" if args.cuda else "cpu")

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # MNIST dataset, normalized to [0, 1]
    try:
        with open(args.dataset, 'rb') as f:
            dataset_dict = pickle.load(f)
    except BaseException as e:
        print(str(e.__class__.__name__) + ": " + str(e))
        exit()

    X_train_labeled = dataset_dict["X_train_labeled"]
    y_train_labeled = dataset_dict["y_train_labeled"]
    X_train_unlabeled = dataset_dict["X_train_unlabeled"]
    y_train_unlabeled = dataset_dict["y_train_unlabeled"]
    X_val = dataset_dict["X_val"]
    y_val = dataset_dict["y_val"]
    X_test = dataset_dict["X_test"]
    y_test = dataset_dict["y_test"]

    labeled_dataset = TensorDataset(
        torch.from_numpy(X_train_labeled).float(),
        torch.from_numpy(y_train_labeled).long())
    unlabeled_dataset = TensorDataset(
        torch.from_numpy(X_train_unlabeled).float(),
        torch.from_numpy(y_train_unlabeled).long())
    val_dataset = TensorDataset(
        torch.from_numpy(X_val).float(),
        torch.from_numpy(y_val).long())
    test_dataset = TensorDataset(
        torch.from_numpy(X_test).float(),
        torch.from_numpy(y_test).long())

    NUM_SAMPLES = len(labeled_dataset) + len(unlabeled_dataset)
    NUM_LABELED = len(labeled_dataset)

    labeled_dataloader = DataLoader(labeled_dataset,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    num_workers=args.workers,
                                    drop_last=False)
    unlabeled_dataloader = DataLoader(unlabeled_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers,
                                      drop_last=False)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

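    # weight on the supervised cross-entropy term: eta scaled by the ratio of
    # all training samples to labeled samples, so the labeled objective is not
    # drowned out by the unlabeled one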
    alpha = args.eta * NUM_SAMPLES / NUM_LABELED
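    # temperature schedule, cosine-annealed from 1.0 down to 0.5 over args.tw
    # steps and passed to the model at every step (e.g. for relaxed sampling)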
    tau = CosineAnnealing(start=1.0, stop=0.5, t_max=args.tw, mode='down')

    model = MnistViVA(z_dim=args.z_dim,
                      hidden_dim=args.hidden,
                      zeta=args.zeta,
                      rho=args.rho,
                      device=device).to(device)
    optimizer = Adam(model.parameters())

    best_val_epoch = 0
    best_val_loss = sys.float_info.max
    best_val_acc = 0.0
    test_acc = 0.0
    early_stop_counter = 0

    if args.resume:
        if os.path.isfile(args.resume):
            print("===> Loading Checkpoint to Resume '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            best_val_epoch = checkpoint['best_epoch']
            best_val_loss = checkpoint['best_val_loss']
            # key name must match the one written by save_checkpoint() below
            best_val_acc = checkpoint['best_val_accuracy']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("\t===> Loaded Checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "\t====> no checkpoint found at '{}'".format(args.resume))

    n_batches = len(labeled_dataloader) + len(unlabeled_dataloader)
    n_unlabeled_per_labeled = len(unlabeled_dataloader) // len(
        labeled_dataloader) + 1
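    # roughly one labeled batch is interleaved per `n_unlabeled_per_labeled`
    # unlabeled batches, so the labeled data is spread across the whole epoch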

    with tqdm(range(args.start_epoch, args.epochs), desc="Epochs") as nested:
        for epoch in nested:

            # Train
            model.train()
            train_recon_loss = AverageMeter('Train_Recon_Loss')
            train_latent_loss = AverageMeter('Train_Latent_Loss')
            train_label_loss = AverageMeter('Train_Label_Loss')
            train_tsne_loss = AverageMeter('Train_tSNE_Loss')
            train_total_loss = AverageMeter('Train_Total_Loss')
            train_accuracy = AverageMeter('Train_Accuracy')

            labeled_iter = iter(labeled_dataloader)
            unlabeled_iter = iter(unlabeled_dataloader)

            for batch_idx in range(n_batches):

                is_supervised = batch_idx % n_unlabeled_per_labeled == 0
                # get batch from respective dataloader
                if is_supervised:
                    try:
                        data, target = next(labeled_iter)
                        data = data.to(device)
                        target = target.to(device)
                        one_hot_target = one_hot(target, 10)
                    except StopIteration:
                        data, target = next(unlabeled_iter)
                        data = data.to(device)
                        target = target.to(device)
                        one_hot_target = None
                else:
                    data, target = next(unlabeled_iter)
                    data = data.to(device)
                    target = target.to(device)
                    one_hot_target = None

                model.zero_grad()

                recon_loss_sum, y_logits, t_coords, latent_loss_sum, tsne_loss = model(
                    data, one_hot_target, tau.step())
                recon_loss = recon_loss_sum / data.size(0)
                label_loss = F.cross_entropy(y_logits,
                                             target,
                                             reduction='mean')
                latent_loss = latent_loss_sum / data.size(0)

                # Full loss
                total_loss = recon_loss + latent_loss + args.gamma * tsne_loss
                if is_supervised and one_hot_target is not None:
                    total_loss += alpha * label_loss

                assert not np.isnan(
                    total_loss.item()), 'Model diverged with loss = NaN'

                train_recon_loss.update(recon_loss.item())
                train_latent_loss.update(latent_loss.item())
                train_label_loss.update(label_loss.item())
                train_tsne_loss.update(tsne_loss.item())
                train_total_loss.update(total_loss.item())

                total_loss.backward()
                optimizer.step()

                pred = y_logits.argmax(
                    dim=1,
                    keepdim=True)  # get the index of the max log-probability
                train_correct = pred.eq(target.view_as(pred)).sum().item()
                train_accuracy.update(train_correct / data.size(0),
                                      data.size(0))

                if batch_idx % args.log_interval == 0:
                    tqdm.write(
                        'Train Epoch: {} [{}/{} ({:.0f}%)]\t Recon: {:.6f} Latent: {:.6f} t-SNE: {:.6f} Accuracy {:.4f} T {:.6f}'
                        .format(epoch, batch_idx, n_batches,
                                100. * batch_idx / n_batches,
                                train_recon_loss.avg, train_latent_loss.avg,
                                train_tsne_loss.avg, train_accuracy.avg,
                                tau.value))

            tqdm.write(
                '====> Epoch: {} Average train loss - Recon {:.3f} Latent {:.3f} t-SNE {:.6f} Label {:.6f} Accuracy {:.4f}'
                .format(epoch, train_recon_loss.avg, train_latent_loss.avg,
                        train_tsne_loss.avg, train_label_loss.avg,
                        train_accuracy.avg))

            # Validation
            model.eval()

            val_recon_loss = AverageMeter('Val_Recon_Loss')
            val_latent_loss = AverageMeter('Val_Latent_Loss')
            val_label_loss = AverageMeter('Val_Label_Loss')
            val_tsne_loss = AverageMeter('Val_tSNE_Loss')
            val_total_loss = AverageMeter('Val_Total_Loss')
            val_accuracy = AverageMeter('Val_Accuracy')

            with torch.no_grad():
                for i, (data, target) in enumerate(val_loader):
                    data = data.to(device)
                    target = target.to(device)

                    recon_loss_sum, y_logits, t_coords, latent_loss_sum, tsne_loss = model(
                        data, temperature=tau.value)

                    recon_loss = recon_loss_sum / data.size(0)
                    label_loss = F.cross_entropy(y_logits,
                                                 target,
                                                 reduction='mean')
                    latent_loss = latent_loss_sum / data.size(0)

                    # Full loss
                    total_loss = recon_loss + latent_loss + args.gamma * tsne_loss + alpha * label_loss

                    val_recon_loss.update(recon_loss.item())
                    val_latent_loss.update(latent_loss.item())
                    val_label_loss.update(label_loss.item())
                    val_tsne_loss.update(tsne_loss.item())
                    val_total_loss.update(total_loss.item())

                    pred = y_logits.argmax(
                        dim=1, keepdim=True
                    )  # get the index of the max log-probability
                    val_correct = pred.eq(target.view_as(pred)).sum().item()
                    val_accuracy.update(val_correct / data.size(0),
                                        data.size(0))

            tqdm.write(
                '\t Validation loss - Recon {:.3f} Latent {:.3f} t-SNE {:.6f} Label: {:.6f} Accuracy {:.4f}'
                .format(val_recon_loss.avg, val_latent_loss.avg,
                        val_tsne_loss.avg, val_label_loss.avg,
                        val_accuracy.avg))

            is_best = val_accuracy.avg > best_val_acc
            if is_best:
                early_stop_counter = 0
                best_val_epoch = epoch
                best_val_loss = val_total_loss.avg
                best_val_acc = val_accuracy.avg

                test_accuracy = AverageMeter('Test_Accuracy')
                with torch.no_grad():
                    for i, (data, target) in enumerate(test_loader):
                        data = data.to(device)
                        target = target.to(device)

                        _, y_logits, _, _, _ = model(data,
                                                     temperature=tau.value)

                        pred = y_logits.argmax(
                            dim=1, keepdim=True
                        )  # get the index of the max log-probability
                        test_correct = pred.eq(
                            target.view_as(pred)).sum().item()
                        test_accuracy.update(test_correct / data.size(0),
                                             data.size(0))

                test_acc = test_accuracy.avg
                tqdm.write('\t Test Accuracy {:.4f}'.format(test_acc))
                with open(os.path.join(args.save, 'train_result.txt'),
                          'w') as f:
                    f.write('Best Validation Epoch: {}\n'.format(epoch))
                    f.write('Train Recon Loss: {}\n'.format(
                        train_recon_loss.avg))
                    f.write('Train Latent Loss: {}\n'.format(
                        train_latent_loss.avg))
                    f.write('Train tSNE Loss: {}\n'.format(
                        train_tsne_loss.avg))
                    f.write('Train Label Loss: {}\n'.format(
                        train_label_loss.avg))
                    f.write('Train Total Loss: {}\n'.format(
                        train_total_loss.avg))
                    f.write('Train Accuracy: {}\n'.format(train_accuracy.avg))
                    f.write('Val Recon Loss: {}\n'.format(val_recon_loss.avg))
                    f.write('Val Latent Loss: {}\n'.format(
                        val_latent_loss.avg))
                    f.write('Val tSNE Loss: {}\n'.format(val_tsne_loss.avg))
                    f.write('Val Label Loss: {}\n'.format(val_label_loss.avg))
                    f.write('Val Total Loss: {}\n'.format(val_total_loss.avg))
                    f.write('Val Accuracy: {}\n'.format(val_accuracy.avg))
                    f.write('Test Accuracy: {}\n'.format(test_acc))
            else:
                early_stop_counter += 1

            save_checkpoint(
                {
                    'epoch': epoch,
                    'best_epoch': best_val_epoch,
                    'best_val_loss': best_val_loss,
                    'best_val_accuracy': best_val_acc,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                filename=os.path.join(args.save, 'checkpoint.pth'))

            if args.early_stop > 0 and early_stop_counter == args.early_stop:
                tqdm.write(
                    "Early Stop with no improvement: epoch {}".format(epoch))
                break

    print("Training is Completed!")
    print("Best Val Acc: {:.4f} Test Acc: {:.4f}".format(
        best_val_acc, test_acc))
Example #11
class Brain:
    def __init__(self, state_shape, n_actions, device, n_workers, epochs,
                 n_iters, epsilon, lr):
        self.state_shape = state_shape
        self.n_actions = n_actions
        self.device = device
        self.n_workers = n_workers
        self.mini_batch_size = 32
        self.epochs = epochs
        self.n_iters = n_iters
        self.initial_epsilon = epsilon
        self.epsilon = self.initial_epsilon
        self.lr = lr

        self.current_policy = Model(self.state_shape,
                                    self.n_actions).to(self.device)

        self.optimizer = Adam(self.current_policy.parameters(),
                              lr=self.lr,
                              eps=1e-5)
        self._schedule_fn = lambda step: max(1.0 - float(step / self.n_iters),
                                             0)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=self._schedule_fn)

    def get_actions_and_values(self, state, batch=False):
        if not batch:
            state = np.expand_dims(state, 0)
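        # NHWC byte frames -> NCHW layout expected by the policy network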
        state = from_numpy(state).byte().permute([0, 3, 1, 2]).to(self.device)
        with torch.no_grad():
            dist, value = self.current_policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
        return action.cpu().numpy(), value.detach().cpu().numpy().squeeze(
        ), log_prob.cpu().numpy()

    def choose_mini_batch(self, states, actions, returns, advs, values,
                          log_probs):
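        # sample a random mini-batch of timesteps independently for each worker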
        for worker in range(self.n_workers):
            idxes = np.random.randint(0, states.shape[1], self.mini_batch_size)
            yield states[worker][idxes], actions[worker][idxes], returns[worker][idxes], advs[worker][idxes], \
                  values[worker][idxes], log_probs[worker][idxes]

    def train(self, states, actions, rewards, dones, values, log_probs,
              next_values):
        returns = self.get_gae(rewards, values.copy(), next_values, dones)
        values = np.vstack(
            values)  # .reshape((len(values[0]) * self.n_workers,))
        advs = returns - values
        advs = (advs - advs.mean(1).reshape((-1, 1))) / (advs.std(1).reshape(
            (-1, 1)) + 1e-8)
        for epoch in range(self.epochs):
            for state, action, q_value, adv, old_value, old_log_prob in self.choose_mini_batch(
                    states, actions, returns, advs, values, log_probs):
                state = torch.ByteTensor(state).permute([0, 3, 1,
                                                         2]).to(self.device)
                action = torch.Tensor(action).to(self.device)
                adv = torch.Tensor(adv).to(self.device)
                q_value = torch.Tensor(q_value).to(self.device)
                old_value = torch.Tensor(old_value).to(self.device)
                old_log_prob = torch.Tensor(old_log_prob).to(self.device)

                dist, value = self.current_policy(state)
                entropy = dist.entropy().mean()
                new_log_prob = self.calculate_log_probs(
                    self.current_policy, state, action)
                ratio = (new_log_prob - old_log_prob).exp()
                actor_loss = self.compute_ac_loss(ratio, adv)

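                # PPO-style value clipping: keep the new value estimate within
                # epsilon of the old one and take the worse (max) of the two
                # squared errors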
                clipped_value = old_value + torch.clamp(
                    value.squeeze() - old_value, -self.epsilon, self.epsilon)
                clipped_v_loss = (clipped_value - q_value).pow(2)
                unclipped_v_loss = (value.squeeze() - q_value).pow(2)
                critic_loss = 0.5 * torch.max(clipped_v_loss,
                                              unclipped_v_loss).mean()

                total_loss = critic_loss + actor_loss - 0.01 * entropy
                self.optimize(total_loss)

        return total_loss.item(), entropy.item(), \
               explained_variance(values.reshape((len(returns[0]) * self.n_workers,)),
                                  returns.reshape((len(returns[0]) * self.n_workers,)))

    def schedule_lr(self):
        self.scheduler.step()

    def schedule_clip_range(self, iter):
        self.epsilon = max(1.0 - float(iter / self.n_iters),
                           0) * self.initial_epsilon

    def optimize(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.current_policy.parameters(), 0.5)
        self.optimizer.step()

    def get_gae(self,
                rewards,
                values,
                next_values,
                dones,
                gamma=0.99,
                lam=0.95):

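        # Generalized Advantage Estimation:
        #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #     A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
        # the values returned are A_t + V(s_t), i.e. the lambda-returns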
        returns = [[] for _ in range(self.n_workers)]
        extended_values = np.zeros((self.n_workers, len(rewards[0]) + 1))
        for worker in range(self.n_workers):
            extended_values[worker] = np.append(values[worker],
                                                next_values[worker])
            gae = 0
            for step in reversed(range(len(rewards[worker]))):
                delta = rewards[worker][step] + \
                        gamma * (extended_values[worker][step + 1]) * (1 - dones[worker][step]) \
                        - extended_values[worker][step]
                gae = delta + gamma * lam * (1 - dones[worker][step]) * gae
                returns[worker].insert(0, gae + extended_values[worker][step])

        return np.vstack(
            returns)  # .reshape((len(returns[0]) * self.n_workers,))

    @staticmethod
    def calculate_log_probs(model, states, actions):
        policy_distribution, _ = model(states)
        return policy_distribution.log_prob(actions)

    def compute_ac_loss(self, ratio, adv):
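        # PPO clipped surrogate objective:
        #     L = -E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]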
        new_r = ratio * adv
        clamped_r = torch.clamp(ratio, 1 - self.epsilon,
                                1 + self.epsilon) * adv
        loss = torch.min(new_r, clamped_r)
        loss = -loss.mean()
        return loss

    def save_params(self, iteration, running_reward):
        torch.save(
            {
                "current_policy_state_dict": self.current_policy.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "iteration": iteration,
                "running_reward": running_reward,
                "clip_range": self.epsilon
            }, "params.pth")

    def load_params(self):
        checkpoint = torch.load("params.pth", map_location=self.device)
        self.current_policy.load_state_dict(
            checkpoint["current_policy_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        iteration = checkpoint["iteration"]
        running_reward = checkpoint["running_reward"]
        self.epsilon = checkpoint["clip_range"]

        return running_reward, iteration

    def set_to_eval_mode(self):
        self.current_policy.eval()
Example #12
            # compute the loss for this batch
            l = loss(logits_, ans_)
            loss_item = l.item()

            _, predicted = torch.max(logits_.data, 1)
            acc_item = 100. * (predicted == ans_).sum().item() / ans_.size(0)

            if loss_item < 1. and acc_item > 90.:
                show_ans(ans_labels)

            logger.info('{} of {}: {} {}%'.format(idx, len(train_dataset), loss_item, acc_item))

            optim.zero_grad()
            l.backward()
            optim.step()

            torch.cuda.empty_cache()

        # show_my_result(image, mask, labels)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': loss,
    }, './ckpts/{}'.format(epoch))


Example #13
class LSTMNetworks(nn.Module, ObservableData):
    '''
    Long Short-Term Memory (LSTM) networks.

    LSTM networks are a special RNN structure that has proven stable and
    powerful for modeling long-range dependencies.

    The key structural extension is the memory cell, which essentially acts
    as an accumulator of state information. Each time new observed data
    points arrive at the input gate, their information is accumulated into
    the cell if the input gate is activated. The past state of the cell can
    be forgotten in this process if the forget gate is on. Whether the
    latest cell output is propagated to the final state is further
    controlled by the output gate.

    References:
        - Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078.
        - Malhotra, P., Ramakrishnan, A., Anand, G., Vig, L., Agarwal, P., & Shroff, G. (2016). LSTM-based encoder-decoder for multi-sensor anomaly detection. arXiv preprint arXiv:1607.00148.
        - Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural network regularization. arXiv preprint arXiv:1409.2329.

    '''
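
    # A rough sketch of the gate updates implemented in `__lstm_forward` below,
    # using the default sigmoid/tanh activations (x_t: input, h_t: hidden state,
    # c_t: memory cell; the W/U/V weights are illustrative names only):
    #     g_t = tanh(W_g x_t + U_g h_{t-1})                      # cell candidate
    #     i_t = sigmoid(W_i x_t + U_i h_{t-1} + V_i c_{t-1})     # input gate
    #     f_t = sigmoid(W_f x_t + U_f h_{t-1} + V_f c_{t-1})     # forget gate
    #     c_t = i_t * g_t + f_t * c_{t-1}
    #     o_t = sigmoid(W_o x_t + U_o h_{t-1} + V_o c_t)         # output gate
    #     h_t = o_t * tanh(c_t)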

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss,
        initializer_f=None,
        optimizer_f=None,
        learning_rate=1e-05,
        seq_len=None,
        hidden_n=200,
        output_n=1,
        dropout_rate=0.5,
        input_adjusted_flag=True,
        observed_activation=torch.nn.Tanh(),
        input_gate_activation=torch.nn.Sigmoid(),
        forget_gate_activation=torch.nn.Sigmoid(),
        output_gate_activation=torch.nn.Sigmoid(),
        hidden_activation=torch.nn.Tanh(),
        output_activation=torch.nn.Tanh(),
        output_layer_flag=True,
        output_no_bias_flag=False,
        output_nn=None,
        ctx="cpu",
        regularizatable_data_list=[],
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:                is-a `ComputableLoss` or `torch.nn.modules.loss._Loss`.
            initializer_f:                  function that initializes the parameters of the model.
                                            If `None`, weights are drawn from the Xavier distribution.
            optimizer_f:                    function that builds an optimizer from the model parameters.
                                            If `None`, `torch.optim.Adam` is used.
            learning_rate:                  `float` of learning rate.

            seq_len:                        `int` of the length of sequences.
                                            This means the referenced maximum step `t` in feedforward.
                                            If `None`, this model will reference all series elements included
                                            in observed data points.
                                            If not `None`, only the first sequence will be observed by this model
                                            and will be fed forward as feature points.
                                            This parameter enables you to build this class as a `Decoder` in
                                            a Sequence-to-Sequence (Seq2seq) scheme.

            hidden_n:                       `int` of the number of units in the hidden layer.
            output_n:                       `int` of the number of units in the output layer.
            dropout_rate:                   `float` of dropout rate.
            input_adjusted_flag:            `bool` of flag that means this class will adjust observed data points by normalization.
            observed_activation:            activation function that activates observed data points.
            input_gate_activation:          activation function in the input gate.
            forget_gate_activation:         activation function in the forget gate.
            output_gate_activation:         activation function in the output gate.
            hidden_activation:              activation function in the hidden layer.
            output_activation:              activation function in the output layer.
                                            If this value is `"identity"`, the activation function is the identity function.

            output_layer_flag:              `bool` that means this class has an output layer or not.
            output_no_bias_flag:            `bool` for using bias or not in the output layer (last hidden layer).
            output_nn:                      output network used as output layers.
                                            If not `None`, `output_layer_flag` and `output_no_bias_flag` will be ignored.

            ctx:                            device context, e.g. `"cpu"` or `"cuda"`.
            regularizatable_data_list:      `list` of `RegularizatableData`.
            not_init_flag:                  `bool` that means parameter initialization in this class will be skipped or not.
        '''
        if isinstance(computable_loss, ComputableLoss) is False and isinstance(
                computable_loss, nn.modules.loss._Loss) is False:
            raise TypeError(
                "The type of `computable_loss` must be `ComputableLoss` or `gluon.loss.Loss`."
            )

        super(LSTMNetworks, self).__init__()
        self.initializer_f = initializer_f
        self.optimizer_f = optimizer_f
        self.__not_init_flag = not_init_flag

        if dropout_rate > 0.0:
            self.dropout_forward = nn.Dropout(p=dropout_rate)
        else:
            self.dropout_forward = None

        self.__observed_activation = observed_activation
        self.__input_gate_activation = input_gate_activation
        self.__forget_gate_activation = forget_gate_activation
        self.__output_gate_activation = output_gate_activation
        self.__hidden_activation = hidden_activation
        self.__output_activation = output_activation
        self.__output_layer_flag = output_layer_flag

        self.__computable_loss = computable_loss
        self.__learning_rate = learning_rate
        self.__hidden_n = hidden_n
        self.__output_n = output_n
        self.__dropout_rate = dropout_rate
        self.__input_adjusted_flag = input_adjusted_flag

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `Regularizatable`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__ctx = ctx

        logger = getLogger("accelbrainbase")
        self.__logger = logger

        self.__input_dim = None
        self.__input_seq_len = None

        self.__output_layer_flag = output_layer_flag
        self.__output_no_bias_flag = output_no_bias_flag
        self.__output_nn = output_nn
        self.seq_len = seq_len

        self.epoch = 0
        self.__loss_list = []

    def initialize_params(self, input_dim, input_seq_len):
        '''
        Initialize params.

        Args:
            input_dim:      The number of units in the input layer.
            input_seq_len:  The length of the input sequences.
        '''
        if self.__input_dim is not None:
            return
        self.__input_dim = input_dim
        self.__input_seq_len = input_seq_len

        if self.__not_init_flag is False:
            if self.init_deferred_flag is False:
                self.observed_fc = nn.Linear(
                    input_dim,
                    self.__hidden_n * 4,
                    bias=False,
                )
                if self.initializer_f is None:
                    self.observed_fc.weight = torch.nn.init.xavier_normal_(
                        self.observed_fc.weight, gain=1.0)
                else:
                    self.observed_fc.weight = self.initializer_f(
                        self.observed_fc.weight)

                self.hidden_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n * 4,
                )
                if self.initializer_f is None:
                    self.hidden_fc.weight = torch.nn.init.xavier_normal_(
                        self.hidden_fc.weight, gain=1.0)
                else:
                    self.hidden_fc.weight = self.initializer_f(
                        self.hidden_fc.weight)

                self.input_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )
                self.forget_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )
                self.output_gate_fc = nn.Linear(
                    self.__hidden_n,
                    self.__hidden_n,
                    bias=False,
                )

            self.output_fc = None
            self.output_nn = None
            if self.__output_layer_flag is True and self.__output_nn is None:
                if self.__output_no_bias_flag is True:
                    use_bias = False
                else:
                    use_bias = True

                # Different from mxnet version.
                self.output_fc = nn.Linear(
                    self.__hidden_n * self.__input_seq_len,
                    self.__output_n * self.__input_seq_len,
                    bias=use_bias)
                self.__output_dim = self.__output_n
            elif self.__output_nn is not None:
                self.output_nn = self.__output_nn
                self.__output_dim = self.output_nn.units_list[-1]
            else:
                self.__output_dim = self.__hidden_n

        self.to(self.__ctx)
        if self.init_deferred_flag is False:
            if self.__not_init_flag is False:
                if self.optimizer_f is None:
                    self.optimizer = Adam(
                        self.parameters(),
                        lr=self.__learning_rate,
                    )
                else:
                    self.optimizer = self.optimizer_f(self.parameters(), )

    def learn(self, iteratable_data):
        '''
        Learn the observed data points
        for vector representation of the input time-series.

        Args:
            iteratable_data:     is-a `IteratableData`.

        '''
        if isinstance(iteratable_data, IteratableData) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate
        try:
            epoch = self.epoch
            iter_n = 0
            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples(
            ):
                self.__batch_size = batch_observed_arr.shape[0]
                self.__seq_len = batch_observed_arr.shape[1]
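                # parameters are initialized lazily here, once the input
                # dimensionality of the first mini-batch is known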
                self.initialize_params(input_dim=batch_observed_arr.reshape(
                    self.__batch_size * self.__seq_len, -1).shape[-1],
                                       input_seq_len=self.__seq_len)
                if self.output_nn is not None:
                    if hasattr(self.output_nn, "optimizer") is False:
                        _ = self.inference(batch_observed_arr)

                self.optimizer.zero_grad()
                if self.output_nn is not None:
                    self.output_nn.optimizer.zero_grad()

                # rank-3
                pred_arr = self.inference(batch_observed_arr)
                loss = self.compute_loss(pred_arr, batch_target_arr)
                loss.backward()

                if self.output_nn is not None:
                    self.output_nn.optimizer.step()
                self.optimizer.step()
                self.regularize()

                if (iter_n + 1) % int(
                        iteratable_data.iter_n / iteratable_data.epochs) == 0:
                    with torch.inference_mode():
                        # rank-3
                        test_pred_arr = self.inference(test_batch_observed_arr)

                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_batch_target_arr)

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.__logger.debug("Epochs: " + str(epoch + 1) +
                                        " Train loss: " + str(_loss) +
                                        " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1

        except KeyboardInterrupt:
            self.__logger.debug("Interrupt.")

        self.epoch = epoch

        self.__logger.debug("end. ")

    def inference(self, observed_arr):
        '''
        Infer the feature points to reconstruct the time-series.

        Args:
            observed_arr:           rank-3 array-like or sparse matrix of the observed data points.

        Returns:
            `torch.Tensor` of inferred feature points.
        '''
        return self(observed_arr)

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor` of predicted data points.
            labeled_arr:    `torch.Tensor` of labeled data points.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        params_dict = {}
        for k in self.state_dict().keys():
            params_dict.setdefault(k, self.state_dict()[k])

        return params_dict

    def extract_feature_points(self):
        '''
        Extract the activities in the hidden layer,
        considering that this method is called once per cycle over the time-series.

        Returns:
            `torch.Tensor` of feature points or virtual visible observed data points.
        '''
        return self.feature_points_arr

    def forward(self, x):
        '''
        Forward propagation.

        Args:
            x:      `torch.Tensor` of observed data points.

        Returns:
            `torch.Tensor` of inferred feature points.
        '''
        self.__batch_size = x.shape[0]
        self.__seq_len = x.shape[1]
        x = x.reshape(self.__batch_size, self.__seq_len, -1)
        self.initialize_params(input_dim=x.shape[2],
                               input_seq_len=self.__seq_len)

        hidden_activity_arr = self.hidden_forward_propagate(x)

        if self.__dropout_rate > 0:
            hidden_activity_arr = self.dropout_forward(hidden_activity_arr)
        self.feature_points_arr = hidden_activity_arr

        if self.output_nn is not None:
            pred_arr = self.output_nn(hidden_activity_arr)
            return pred_arr
        if self.__output_layer_flag is True:
            # rank-3
            pred_arr = self.output_forward_propagate(hidden_activity_arr)
            return pred_arr
        else:
            return hidden_activity_arr

    def hidden_forward_propagate(self, observed_arr):
        '''
        Forward propagation in LSTM gate.

        Args:
            observed_arr:           rank-3 tensor of observed data points.
        
        Returns:
            Predicted data points.
        '''
        pred_arr = None

        hidden_activity_arr = torch.zeros((self.__batch_size, self.__hidden_n),
                                          dtype=torch.float32)
        hidden_activity_arr = hidden_activity_arr.to(self.__ctx)
        cec_activity_arr = torch.zeros((self.__batch_size, self.__hidden_n),
                                       dtype=torch.float32)
        cec_activity_arr = cec_activity_arr.to(self.__ctx)

        if self.seq_len is not None:
            cycle_n = self.seq_len
        else:
            cycle_n = self.__seq_len

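        # when the input provides only a single timestep (decoder mode), later
        # cycles feed the previous hidden output back in as the next input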
        for cycle in range(cycle_n):
            if cycle == 0:
                if observed_arr[:, cycle:cycle + 1].shape[1] != 0:
                    hidden_activity_arr, cec_activity_arr = self.__lstm_forward(
                        observed_arr[:, cycle:cycle + 1], hidden_activity_arr,
                        cec_activity_arr)
                    skip_flag = False
                else:
                    skip_flag = True
            else:
                if observed_arr.shape[1] > 1:
                    x_arr = observed_arr[:, cycle:cycle + 1]
                else:
                    x_arr = torch.unsqueeze(pred_arr[:, -1], dim=1)

                if x_arr.shape[1] != 0:
                    hidden_activity_arr, cec_activity_arr = self.__lstm_forward(
                        x_arr, hidden_activity_arr, cec_activity_arr)
                    skip_flag = False
                else:
                    skip_flag = True

            if skip_flag is False:
                add_arr = torch.unsqueeze(hidden_activity_arr, dim=1)
                if pred_arr is None:
                    pred_arr = add_arr
                else:
                    pred_arr = torch.cat((pred_arr, add_arr), dim=1)

        return pred_arr

    def __lstm_forward(self, observed_arr, hidden_activity_arr,
                       cec_activity_arr):
        '''
        Forward propagate in LSTM gate.
        
        Args:
            observed_arr:           rank-2 tensor of observed data points.
            hidden_activity_arr:    rank-2 tensor of activities in hidden layer.
            cec_activity_arr:       rank-2 tensor of activities in the constant error carousel.
        
        Returns:
            Tuple data.
            - rank-2 tensor of activities in hidden layer,
            - rank-2 tensor of activities in LSTM gate.
        '''
        if len(observed_arr.shape) == 3:
            observed_arr = observed_arr[:, 0]

        if self.__input_adjusted_flag is True:
            observed_arr = torch.div(observed_arr,
                                     torch.sum(torch.ones_like(observed_arr)))

        observed_lstm_matrix = self.observed_fc(observed_arr)

        # using bias
        hidden_lstm_matrix = self.hidden_fc(hidden_activity_arr)
        lstm_matrix = observed_lstm_matrix + hidden_lstm_matrix

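        # the fused projection is sliced into the four gate pre-activations:
        # [cell candidate, input gate, forget gate, output gate]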
        given_activity_arr = lstm_matrix[:, :self.__hidden_n]
        input_gate_activity_arr = lstm_matrix[:,
                                              self.__hidden_n:self.__hidden_n *
                                              2]
        forget_gate_activity_arr = lstm_matrix[:, self.__hidden_n *
                                               2:self.__hidden_n * 3]
        output_gate_activity_arr = lstm_matrix[:, self.__hidden_n *
                                               3:self.__hidden_n * 4]

        # no bias
        _input_gate_activity_arr = self.input_gate_fc(cec_activity_arr)
        input_gate_activity_arr = input_gate_activity_arr + _input_gate_activity_arr
        # no bias
        _forget_gate_activity_arr = self.forget_gate_fc(cec_activity_arr)
        forget_gate_activity_arr = forget_gate_activity_arr + _forget_gate_activity_arr
        given_activity_arr = self.__observed_activation(given_activity_arr)
        input_gate_activity_arr = self.__input_gate_activation(
            input_gate_activity_arr)
        forget_gate_activity_arr = self.__forget_gate_activation(
            forget_gate_activity_arr)

        # rank-2
        _cec_activity_arr = torch.mul(
            given_activity_arr, input_gate_activity_arr) + torch.mul(
                forget_gate_activity_arr, cec_activity_arr)

        # no bias
        _output_gate_activity_arr = self.output_gate_fc(_cec_activity_arr)

        output_gate_activity_arr = output_gate_activity_arr + _output_gate_activity_arr
        output_gate_activity_arr = self.__output_gate_activation(
            output_gate_activity_arr)

        # rank-2
        _hidden_activity_arr = torch.mul(
            output_gate_activity_arr,
            self.__hidden_activation(_cec_activity_arr))

        return (_hidden_activity_arr, _cec_activity_arr)

    def output_forward_propagate(self, pred_arr):
        '''
        Forward propagation in output layer.
        
        Args:
            pred_arr:            rank-3 tensor of predicted data points.

        Returns:
            rank-3 tensor of propagated data points.
        '''
        if self.__output_layer_flag is False:
            return pred_arr

        batch_size = pred_arr.shape[0]
        seq_len = pred_arr.shape[1]
        # Different from mxnet version.
        pred_arr = self.output_fc(torch.reshape(pred_arr, (batch_size, -1)))
        if self.__output_activation == "identity_adjusted":
            pred_arr = torch.div(pred_arr,
                                 torch.sum(torch.ones_like(pred_arr)))
        elif self.__output_activation != "identity":
            pred_arr = self.__output_activation(pred_arr)
        pred_arr = torch.reshape(pred_arr, (batch_size, seq_len, -1))
        return pred_arr

    def save_parameters(self, filename):
        '''
        Save parameters to a file.

        Args:
            filename:       File name.
        '''
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
                'input_dim': self.__input_dim,
                'input_seq_len': self.__input_seq_len,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.initialize_params(
            input_dim=checkpoint["input_dim"],
            input_seq_len=checkpoint["input_seq_len"],
        )
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter for losses. '''
        raise TypeError("This property must be read-only.")

    __loss_list = []

    def get_loss_arr(self):
        ''' getter for losses. '''
        return np.array(self.__loss_list)

    loss_arr = property(get_loss_arr, set_readonly)

    def get_output_dim(self):
        return self.__output_dim

    output_dim = property(get_output_dim, set_readonly)

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not.'''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)