Example #1
def build_decoder(n_vocabs):
    model = Decoder(model_name=C.decoder_model,
                    n_layers=C.decoder_n_layers,
                    encoder_size=C.encoder_output_size,
                    embedding_size=C.embedding_size,
                    embedding_scale=C.embedding_scale,
                    hidden_size=C.decoder_hidden_size,
                    attn_size=C.decoder_attn_size,
                    output_size=n_vocabs,
                    embedding_dropout=C.embedding_dropout,
                    dropout=C.decoder_dropout,
                    out_dropout=C.decoder_out_dropout)
    model = model.to(C.device)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=C.decoder_learning_rate,
                                 weight_decay=C.decoder_weight_decay,
                                 amsgrad=C.decoder_use_amsgrad)
    # Deprecated torch.autograd.Variable replaced with a tensor created
    # directly on the device, so it stays a leaf with requires_grad=True.
    lambda_reg = torch.tensor(0.001, requires_grad=True, device=C.device)

    decoder = {
        'model': model,
        'loss': loss,
        'optimizer': optimizer,
        'lambda_reg': lambda_reg,
    }
    return decoder
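
A hypothetical call site for build_decoder; the vocab object, features, and targets below are illustrative assumptions, not part of the example:

# Usage sketch, assuming a vocab exposing n_vocabs and a decoder forward
# that maps features to per-token logits.
decoder = build_decoder(n_vocabs=vocab.n_vocabs)
logits = decoder['model'](features)
step_loss = decoder['loss'](logits, targets)
decoder['optimizer'].zero_grad()
step_loss.backward()
decoder['optimizer'].step()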
Example #2
def train():
    torch.backends.cudnn.benchmark = True

    _, dataloader = create_dataloader(config.IMG_DIR + "/train", config.MESH_DIR + "/train",
                                      batch_size=config.BATCH_SIZE, used_layers=config.USED_LAYERS,
                                      img_size=config.IMAGE_SIZE, map_size=config.MAP_SIZE,
                                      augment=config.AUGMENT, workers=config.NUM_WORKERS,
                                      pin_memory=config.PIN_MEMORY, shuffle=True)

    in_channels = num_channels(config.USED_LAYERS)
    encoder = Encoder(in_channels=in_channels)
    decoder = Decoder(num_classes=config.NUM_CLASSES+1)
    encoder.apply(init_weights)
    decoder.apply(init_weights)
    encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=config.ENCODER_LEARNING_RATE,
                                      betas=config.BETAS)
    decoder_solver = torch.optim.Adam(decoder.parameters(),
                                      lr=config.DECODER_LEARNING_RATE,
                                      betas=config.BETAS)
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=config.ENCODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=config.DECODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    encoder = encoder.to(config.DEVICE)
    decoder = decoder.to(config.DEVICE)

    loss_fn = LossFunction()

    init_epoch = 0
    if config.CHECKPOINT_FILE and config.LOAD_MODEL:
        init_epoch, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    output_dir = os.path.join(config.OUT_PATH, re.sub("[^0-9a-zA-Z]+", "-", dt.now().isoformat()))

    for epoch_idx in range(init_epoch, config.NUM_EPOCHS):
        encoder.train()
        decoder.train()
        train_one_epoch(encoder, decoder, dataloader, loss_fn, encoder_solver, decoder_solver, epoch_idx)
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        if config.TEST:
            test(encoder, decoder)
        if config.SAVE_MODEL:
            save_checkpoint(epoch_idx, encoder, decoder, output_dir)

    if not config.TEST:
        test(encoder, decoder)
    if not config.SAVE_MODEL:
        save_checkpoint(config.NUM_EPOCHS - 1, encoder, decoder, output_dir)
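
Example #2 relies on an init_weights helper applied recursively through Module.apply; a minimal sketch of such a helper, assuming Kaiming initialization for convolutional and linear layers (the project's actual helper may differ):

import torch

def init_weights(m):
    # Called once per submodule by Module.apply()
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)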
Example #3
class DeepLabV3Plus(nn.Module):
    def __init__(self, num_classes, in_channels=3, backbone='xception', pretrained=True,
                 output_stride=16, freeze_bn=False, **_):
        super(DeepLabV3Plus, self).__init__()
        assert 'xception' in backbone or 'resnet' in backbone
        self.backbone, low_level_channels = getBackBone(backbone, in_channels=in_channels, output_stride=output_stride,
                                                        pretrained=pretrained)

        self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)

        self.decoder = Decoder(low_level_channels, num_classes)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, x):
        H, W = x.size(2), x.size(3)
        x, low_level_features = self.backbone(x)
        x = self.ASSP(x)
        x = self.decoder(x, low_level_features)
        x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
        return x

    # Two helpers that yield the parameters of the backbone and of the
    # ASSP / decoder head, so each group can be trained with its own
    # (differential) learning rate.
    # FIXME: for the 'xception' backbone we load plain Xception weights,
    # not Aligned Xception, so a higher learning rate for it may work better.

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return chain(self.ASSP.parameters(), self.decoder.parameters())

    def freeze_bn(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
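
The two parameter helpers above exist so the backbone and the ASSP/decoder head can be trained at different rates; a hedged usage sketch (the learning rates and momentum are illustrative values only):

import torch

model = DeepLabV3Plus(num_classes=21, backbone='xception')
# One optimizer, two parameter groups: a smaller LR for the pretrained
# backbone, a larger one for the randomly initialized head.
optimizer = torch.optim.SGD(
    [{'params': model.get_backbone_params(), 'lr': 1e-4},
     {'params': model.get_decoder_params()}],
    lr=1e-3, momentum=0.9)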
Example #4
def train(model_config, train_config):
    mode = 'train'
    dataset = ShakespeareModern(train_shakespeare_path,
                                test_shakespeare_path,
                                train_modern_path,
                                test_modern_path,
                                mode=mode)
    dataloader = DataLoader(dataset,
                            batch_size=train_config['batch_size'],
                            shuffle=False)
    vocab = dataset.vocab
    max_length = dataset.domain_A_max_len
    encoder = Encoder(model_config['embedding_size'],
                      model_config['hidden_dim'],
                      dataset.vocab.num_words,
                      batch_size=train_config['batch_size']).cuda()
    # print(dataset.domain_A_max_len)
    decoder = Decoder(model_config['embedding_size'],
                      model_config['hidden_dim'],
                      dataset.vocab.num_words,
                      max_length,
                      batch_size=train_config['batch_size']).cuda()

    criterion = nn.NLLLoss().cuda()
    encoder_optimizer = torch.optim.SGD(encoder.parameters(),
                                        lr=train_config['base_lr'])
    decoder_optimizer = torch.optim.SGD(decoder.parameters(),
                                        lr=train_config['base_lr'])

    for epoch in range(train_config['num_epochs']):
        for idx, (s, s_addn_feats, m,
                  m_addn_feats) in tqdm(enumerate(dataloader)):
            input_tensor = s.transpose(0, 1).cuda()
            target_tensor = m.transpose(0, 1).cuda()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            input_length = input_tensor.size(0)
            target_length = target_tensor.size(0)

            loss = 0
            encoder_output, encoder_hidden = encoder(input_tensor)

            # Start every sequence in the batch with the SOS token
            decoder_input = torch.full((train_config['batch_size'], 1),
                                       SOS_token,
                                       dtype=torch.long).cuda()

            # Seed the decoder state with the final encoder output
            decoder_hidden = encoder_output[-1]

            # Decode one target step at a time, accumulating the NLL loss
            for di in range(target_length):
                decoder.hidden = decoder_hidden
                decoder_output, decoder_hidden = decoder(
                    decoder_input, encoder_output)
                loss += criterion(decoder_output, target_tensor[di])
                # Teacher forcing: feed the ground-truth token as next input
                decoder_input = target_tensor[di].unsqueeze(1)

            loss.backward()

            encoder_optimizer.step()
            decoder_optimizer.step()

            if idx % 100 == 0:
                print('\tepoch [{}/{}], iter: {}, loss: {:.4f}'.format(
                    epoch + 1, train_config['num_epochs'], idx,
                    loss.item() / target_length))

        print('\tepoch [{}/{}]'.format(epoch + 1, train_config['num_epochs']))

    return loss.item() / target_length
Example #5
def train_net(cfg):
    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
    torch.backends.cudnn.benchmark = True

    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        num_workers=cfg.TRAIN.NUM_WORKER,
        pin_memory=True,
        shuffle=True,
        drop_last=True)
    val_data_loader = torch.utils.data.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        num_workers=1,
        pin_memory=True,
        shuffle=False)

    # Set up networks
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    refiner = Refiner(cfg)
    merger = Merger(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))

    # Initialize weights of networks
    encoder.apply(utils.network_utils.init_weights)
    decoder.apply(utils.network_utils.init_weights)
    refiner.apply(utils.network_utils.init_weights)
    merger.apply(utils.network_utils.init_weights)

    # Set up solver
    if cfg.TRAIN.POLICY == 'adam':
        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                 encoder.parameters()),
                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        decoder_solver = torch.optim.Adam(decoder.parameters(),
                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        refiner_solver = torch.optim.Adam(refiner.parameters(),
                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                          betas=cfg.TRAIN.BETAS)
        merger_solver = torch.optim.Adam(merger.parameters(),
                                         lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                         betas=cfg.TRAIN.BETAS)
    elif cfg.TRAIN.POLICY == 'sgd':
        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                encoder.parameters()),
                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        decoder_solver = torch.optim.SGD(decoder.parameters(),
                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        refiner_solver = torch.optim.SGD(refiner.parameters(),
                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,
                                         momentum=cfg.TRAIN.MOMENTUM)
        merger_solver = torch.optim.SGD(merger.parameters(),
                                        lr=cfg.TRAIN.MERGER_LEARNING_RATE,
                                        momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise Exception('[FATAL] %s Unknown optimizer %s.' %
                        (dt.now(), cfg.TRAIN.POLICY))

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        encoder_solver,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        decoder_solver,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        refiner_solver,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)
    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        merger_solver,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA)

    if torch.cuda.is_available():
        encoder = torch.nn.DataParallel(encoder).cuda()
        decoder = torch.nn.DataParallel(decoder).cuda()
        refiner = torch.nn.DataParallel(refiner).cuda()
        merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = torch.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        checkpoint = torch.load(cfg.CONST.WEIGHTS)
        init_epoch = checkpoint['epoch_idx']
        best_iou = checkpoint['best_iou']
        best_epoch = checkpoint['best_epoch']

        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        if cfg.NETWORK.USE_REFINER:
            refiner.load_state_dict(checkpoint['refiner_state_dict'])
        if cfg.NETWORK.USE_MERGER:
            merger.load_state_dict(checkpoint['merger_state_dict'])

        print('[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' \
                 % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))
    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # Adjust learning rate
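        # NOTE: since PyTorch 1.1, scheduler.step() is expected to follow
        # optimizer.step(); stepping at the start of the epoch matches the
        # pre-1.1 API this code was written against.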
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        refiner_lr_scheduler.step()
        merger_lr_scheduler.step()

        # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)
        for batch_idx, (taxonomy_names, sample_names, rendering_images,
                        ground_truth_volumes) in enumerate(train_data_loader):
            # Measure data time
            data_time.update(time() - batch_end_time)

            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, refiner, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = torch.mean(generated_volumes, dim=1)
            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                refiner_loss = encoder_loss

            # Gradient descent
            encoder.zero_grad()
            decoder.zero_grad()
            refiner.zero_grad()
            merger.zero_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            refiner_solver.step()
            merger_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.item())
            refiner_losses.update(refiner_loss.item())
            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar('EncoderDecoder/BatchLoss',
                                    encoder_loss.item(), n_itr)
            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(),
                                    n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            print('[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' % \
                (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, \
                    batch_time.val, data_time.val, encoder_loss.item(), refiner_loss.item()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg,
                                epoch_idx + 1)
        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg,
                                epoch_idx + 1)

        # Tick / tock
        epoch_end_time = time()
        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %
            (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, \
                encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' % \
                (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, refiner, merger)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(cfg, \
                    os.path.join(ckpt_dir, 'best-ckpt.pth'), \
                    epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver, \
                    refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)

    # Close SummaryWriter for TensorBoard
    train_writer.close()
    val_writer.close()
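
Both Pix2Vox-style loops above lean on utils.network_utils.AverageMeter; a minimal sketch of such a meter, following the common PyTorch-examples recipe (the project's own class may differ):

class AverageMeter(object):
    """Tracks the most recent value and a running average."""

    def __init__(self):
        self.val = 0.0   # last value seen
        self.sum = 0.0   # running sum
        self.count = 0   # number of samples
        self.avg = 0.0   # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count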
Example #6
def train_net(cfg):
    # Set up data augmentation
    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W
    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W
    train_transforms = utils.data_transforms.Compose([
        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(
            cfg.TRAIN.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS,
                                          cfg.TRAIN.CONTRAST,
                                          cfg.TRAIN.SATURATION),
        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.RandomFlip(),
        utils.data_transforms.RandomPermuteRGB(),
        utils.data_transforms.ToTensor(),
    ])
    val_transforms = utils.data_transforms.Compose([
        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),
        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),
        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN,
                                        std=cfg.DATASET.STD),
        utils.data_transforms.ToTensor(),
    ])

    # Set up data loader
    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TRAIN_DATASET](cfg)
    val_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[
        cfg.DATASET.TEST_DATASET](cfg)
    train_data_loader = paddle.io.DataLoader(
        dataset=train_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING,
            train_transforms),
        batch_size=cfg.CONST.BATCH_SIZE,
        #num_workers=0,  # fails when cfg.TRAIN.NUM_WORKER > 0 because /dev/shm is too small: https://blog.csdn.net/ctypyb2002/article/details/107914643
        #pin_memory=True,
        use_shared_memory=False,
        shuffle=True,
        drop_last=True)
    val_data_loader = paddle.io.DataLoader(
        dataset=val_dataset_loader.get_dataset(
            utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING,
            val_transforms),
        batch_size=1,
        #num_workers=1,
        #pin_memory=True,
        shuffle=False)

    # Set up networks # paddle.Model prepare fit save
    encoder = Encoder(cfg)
    decoder = Decoder(cfg)
    merger = Merger(cfg)
    refiner = Refiner(cfg)
    print('[DEBUG] %s Parameters in Encoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(encoder)))
    print('[DEBUG] %s Parameters in Decoder: %d.' %
          (dt.now(), utils.network_utils.count_parameters(decoder)))
    print('[DEBUG] %s Parameters in Merger: %d.' %
          (dt.now(), utils.network_utils.count_parameters(merger)))
    print('[DEBUG] %s Parameters in Refiner: %d.' %
          (dt.now(), utils.network_utils.count_parameters(refiner)))

    # # Initialize weights of networks # Paddle parameterizes initialization differently; see the API docs
    # encoder.apply(utils.network_utils.init_weights)
    # decoder.apply(utils.network_utils.init_weights)
    # merger.apply(utils.network_utils.init_weights)

    # Set up learning rate scheduler to decay learning rates dynamically
    encoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.ENCODER_LEARNING_RATE,
        milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    decoder_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.DECODER_LEARNING_RATE,
        milestones=cfg.TRAIN.DECODER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    merger_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.MERGER_LEARNING_RATE,
        milestones=cfg.TRAIN.MERGER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    refiner_lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        learning_rate=cfg.TRAIN.REFINER_LEARNING_RATE,
        milestones=cfg.TRAIN.REFINER_LR_MILESTONES,
        gamma=cfg.TRAIN.GAMMA,
        verbose=True)
    # Set up solver
    # if cfg.TRAIN.POLICY == 'adam':
    encoder_solver = paddle.optimizer.Adam(learning_rate=encoder_lr_scheduler,
                                           parameters=encoder.parameters())
    decoder_solver = paddle.optimizer.Adam(learning_rate=decoder_lr_scheduler,
                                           parameters=decoder.parameters())
    merger_solver = paddle.optimizer.Adam(learning_rate=merger_lr_scheduler,
                                          parameters=merger.parameters())
    refiner_solver = paddle.optimizer.Adam(learning_rate=refiner_lr_scheduler,
                                           parameters=refiner.parameters())

    # if torch.cuda.is_available():
    #     encoder = torch.nn.DataParallel(encoder).cuda()
    #     decoder = torch.nn.DataParallel(decoder).cuda()
    #     merger = torch.nn.DataParallel(merger).cuda()

    # Set up loss functions
    bce_loss = paddle.nn.BCELoss()

    # Load pretrained model if exists
    init_epoch = 0
    best_iou = -1
    best_epoch = -1
    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:
        print('[INFO] %s Recovering from %s ...' %
              (dt.now(), cfg.CONST.WEIGHTS))
        # load
        encoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder.pdparams"))
        encoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "encoder_solver.pdopt"))
        encoder.set_state_dict(encoder_state_dict)
        encoder_solver.set_state_dict(encoder_solver_state_dict)
        decoder_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder.pdparams"))
        decoder_solver_state_dict = paddle.load(
            os.path.join(cfg.CONST.WEIGHTS, "decoder_solver.pdopt"))
        decoder.set_state_dict(decoder_state_dict)
        decoder_solver.set_state_dict(decoder_solver_state_dict)

        if cfg.NETWORK.USE_MERGER:
            merger_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger.pdparams"))
            merger_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "merger_solver.pdopt"))
            merger.set_state_dict(merger_state_dict)
            merger_solver.set_state_dict(merger_solver_state_dict)

        if cfg.NETWORK.USE_REFINER:
            refiner_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner.pdparams"))
            refiner_solver_state_dict = paddle.load(
                os.path.join(cfg.CONST.WEIGHTS, "refiner_solver.pdopt"))
            refiner.set_state_dict(refiner_state_dict)
            refiner_solver.set_state_dict(refiner_solver_state_dict)

        print(
            '[INFO] %s Recover complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.'
            % (dt.now(), init_epoch, best_iou, best_epoch))

    # Summary writer for TensorBoard
    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())
    log_dir = output_dir % 'logs'
    ckpt_dir = output_dir % 'checkpoints'
    # train_writer = SummaryWriter()
    # val_writer = SummaryWriter(os.path.join(log_dir, 'test'))
    train_writer = LogWriter(os.path.join(log_dir, 'train'))
    val_writer = LogWriter(os.path.join(log_dir, 'val'))

    # Training loop
    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):
        # Tick / tock
        epoch_start_time = time()

        # Batch average metrics
        batch_time = utils.network_utils.AverageMeter()
        data_time = utils.network_utils.AverageMeter()
        encoder_losses = utils.network_utils.AverageMeter()
        refiner_losses = utils.network_utils.AverageMeter()

        # # switch models to training mode
        encoder.train()
        decoder.train()
        merger.train()
        refiner.train()

        batch_end_time = time()
        n_batches = len(train_data_loader)

        # print("****debug: length of train data loder",n_batches)
        for batch_idx, (rendering_images, ground_truth_volumes) in enumerate(
                train_data_loader()):
            # # debug
            # if batch_idx>1:
            #     break

            # Measure data time
            data_time.update(time() - batch_end_time)
            # print("****debug: batch_idx",batch_idx)
            # print(rendering_images.shape)
            # print(ground_truth_volumes.shape)
            # Get data from data loader
            rendering_images = utils.network_utils.var_or_cuda(
                rendering_images)
            ground_truth_volumes = utils.network_utils.var_or_cuda(
                ground_truth_volumes)

            # Train the encoder, decoder, and merger
            image_features = encoder(rendering_images)
            raw_features, generated_volumes = decoder(image_features)

            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:
                generated_volumes = merger(raw_features, generated_volumes)
            else:
                generated_volumes = paddle.mean(generated_volumes, axis=1)

            encoder_loss = bce_loss(generated_volumes,
                                    ground_truth_volumes) * 10

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                generated_volumes = refiner(generated_volumes)
                refiner_loss = bce_loss(generated_volumes,
                                        ground_truth_volumes) * 10
            else:
                # Keep refiner_loss defined for the meters and logs below
                refiner_loss = encoder_loss

            # Gradient descent
            encoder_solver.clear_grad()
            decoder_solver.clear_grad()
            merger_solver.clear_grad()
            refiner_solver.clear_grad()

            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:
                encoder_loss.backward(retain_graph=True)
                refiner_loss.backward()
            else:
                encoder_loss.backward()

            encoder_solver.step()
            decoder_solver.step()
            merger_solver.step()
            refiner_solver.step()

            # Append loss to average metrics
            encoder_losses.update(encoder_loss.numpy())
            refiner_losses.update(refiner_loss.numpy())

            # Append loss to TensorBoard
            n_itr = epoch_idx * n_batches + batch_idx
            train_writer.add_scalar(tag='EncoderDecoder/BatchLoss',
                                    step=n_itr,
                                    value=encoder_loss.numpy())
            train_writer.add_scalar('Refiner/BatchLoss',
                                    value=refiner_loss.numpy(),
                                    step=n_itr)

            # Tick / tock
            batch_time.update(time() - batch_end_time)
            batch_end_time = time()
            if (batch_idx % int(cfg.CONST.INFO_BATCH)) == 0:
                print(
                    '[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
                    % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES,
                       batch_idx + 1, n_batches, batch_time.val, data_time.val,
                       encoder_loss.numpy(), refiner_loss.numpy()))

        # Append epoch loss to TensorBoard
        train_writer.add_scalar(tag='EncoderDecoder/EpochLoss',
                                step=epoch_idx + 1,
                                value=encoder_losses.avg)
        train_writer.add_scalar('Refiner/EpochLoss',
                                value=refiner_losses.avg,
                                step=epoch_idx + 1)

        # update scheduler each step
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()
        merger_lr_scheduler.step()
        refiner_lr_scheduler.step()

        # Tick / tock
        epoch_end_time = time()
        print(
            '[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'
            % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time -
               epoch_start_time, encoder_losses.avg, refiner_losses.avg))

        # Update Rendering Views
        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:
            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)
            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES,
                   n_views_rendering))

        # Validate the training models
        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader,
                       val_writer, encoder, decoder, merger, refiner)

        # Save weights to file
        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir,
                                  'ckpt-epoch-%04d' % (epoch_idx + 1)),
                epoch_idx + 1, encoder, encoder_solver, decoder,
                decoder_solver, merger, merger_solver, refiner, refiner_solver,
                best_iou, best_epoch)
        if iou > best_iou:
            if not os.path.exists(ckpt_dir):
                os.makedirs(ckpt_dir)

            best_iou = iou
            best_epoch = epoch_idx + 1
            utils.network_utils.save_checkpoints(
                cfg, os.path.join(ckpt_dir, 'best-ckpt'), epoch_idx + 1,
                encoder, encoder_solver, decoder, decoder_solver, merger,
                merger_solver, refiner, refiner_solver, best_iou, best_epoch)
Example #7
def main(_run, _config, _log):
    if _run.observers:
        os.makedirs(f'{_run.observers[0].dir}/snapshots', exist_ok=True)
        for source_file, _ in _run.experiment_info['sources']:
            os.makedirs(os.path.dirname(
                f'{_run.observers[0].dir}/source/{source_file}'),
                        exist_ok=True)
            _run.observers[0].save_file(source_file, f'source/{source_file}')
        shutil.rmtree(f'{_run.observers[0].basedir}/_sources')

    set_seed(_config['seed'])
    cudnn.enabled = True
    cudnn.benchmark = True
    device = torch.device(f"cuda:{_config['gpu_id']}")

    resize_dim = _config['input_size']
    encoded_h = int(resize_dim[0] / 2**_config['n_pool'])
    encoded_w = int(resize_dim[1] / 2**_config['n_pool'])

    s_encoder = SupportEncoder(_config['path']['init_path'],
                               device)  #.to(device)
    q_encoder = QueryEncoder(_config['path']['init_path'],
                             device)  #.to(device)
    decoder = Decoder(input_res=(encoded_h, encoded_w),
                      output_res=resize_dim).to(device)

    _log.info('###### Load data ######')
    data_name = _config['dataset']
    if data_name == 'prostate':
        make_data = meta_data
    else:
        raise ValueError('Wrong config for dataset!')

    tr_dataset, val_dataset, ts_dataset = make_data(_config)
    trainloader = DataLoader(
        dataset=tr_dataset,
        batch_size=_config['batch_size'],
        shuffle=True,
        num_workers=_config['n_work'],
        pin_memory=False,  # True would overlap host-to-GPU copies with training
        drop_last=True)

    _log.info('###### Set optimizer ######')
    print(_config['optim'])
    optimizer = torch.optim.Adam(  #list(initializer.parameters()) +
        list(s_encoder.parameters()) + list(q_encoder.parameters()) +
        list(decoder.parameters()), _config['optim']['lr'])
    scheduler = MultiStepLR(optimizer,
                            milestones=_config['lr_milestones'],
                            gamma=0.1)
    pos_weight = torch.tensor([0.3, 1], dtype=torch.float).to(device)
    criterion = nn.BCELoss()
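    # NOTE: pos_weight above is currently unused; nn.BCELoss takes no
    # pos_weight argument (nn.BCEWithLogitsLoss accepts one).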

    if _config['record']:  ## tensorboard visualization
        _log.info('###### define tensorboard writer #####')
        _log.info(f'##### board/train_{_config["board"]}_{date()}')
        writer = SummaryWriter(f'board/train_{_config["board"]}_{date()}')

    iter_n_train = len(trainloader)
    _log.info('###### Training ######')
    for i_epoch in range(_config['n_steps']):
        loss_epoch = 0
        blank = torch.zeros([1, 256, 256]).to(device)

        for i_iter, sample_train in enumerate(trainloader):
            ## training stage
            optimizer.zero_grad()
            s_x = sample_train['s_x'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            s_y = sample_train['s_y'].to(
                device)  # [B, Support, slice_num, 1, 256, 256]
            q_x = sample_train['q_x'].to(device)  #[B, slice_num, 1, 256, 256]
            q_y = sample_train['q_y'].to(device)  #[B, slice_num, 1, 256, 256]
            # loss_per_video = 0.0
            s_xi = s_x[:, :, 0, :, :, :]  # [B, Support, 1, 256, 256]
            s_yi = s_y[:, :, 0, :, :, :]

            # for s_idx in range(_config["n_shot"]):
            s_x_merge = s_xi.view(s_xi.size(0) * s_xi.size(1), 1, 256, 256)
            s_y_merge = s_yi.view(s_yi.size(0) * s_yi.size(1), 1, 256, 256)
            s_xi_encode_merge, _ = s_encoder(s_x_merge,
                                             s_y_merge)  # [B*S, 512, w, h]

            s_xi_encode = s_xi_encode_merge.view(s_yi.size(0), s_yi.size(1),
                                                 512, encoded_w, encoded_h)
            s_xi_encode_avg = torch.mean(s_xi_encode, dim=1)
            # s_xi_encode, _ = s_encoder(s_xi, s_yi)  # [B, 512, w, h]
            q_xi = q_x[:, 0, :, :, :]
            q_yi = q_y[:, 0, :, :, :]
            q_xi_encode, q_ft_list = q_encoder(q_xi)
            sq_xi = torch.cat((s_xi_encode_avg, q_xi_encode), dim=1)
            yhati = decoder(sq_xi, q_ft_list)  # [B, 1, 256, 256]
            loss = criterion(yhati, q_yi)
            # loss_per_video += loss
            # loss_per_video.backward()
            loss.backward()
            optimizer.step()
            loss_epoch += loss.item()  # accumulate a float, not the autograd graph
            print(f"train, iter:{i_iter}/{iter_n_train}, iter_loss:{loss.item():.4f}",
                  end='\r')

            if _config['record'] and i_iter == 0:
                batch_i = 0
                frames = []
                frames += overlay_color(q_xi[batch_i],
                                        yhati[batch_i].round(),
                                        q_yi[batch_i],
                                        scale=_config['scale'])
                visual = make_grid(frames, normalize=True, nrow=2)
                writer.add_image("train/visual", visual, i_epoch)

                # frames += overlay_color(s_xi[batch_i], blank, s_yi[batch_i], scale=_config['scale'])
                visual = make_grid(frames, normalize=True, nrow=5)
                writer.add_image("valid/visual", visual, i_epoch)

        print(
            f"train - epoch:{i_epoch}/{_config['n_steps']}, epoch_loss:{loss_epoch}",
            end='\n')
        save_fname = f'{_run.observers[0].dir}/snapshots/last.pth'

        _run.log_scalar("training.loss", float(loss_epoch), i_epoch)
        if _config['record']:
            writer.add_scalar('loss/train_loss', loss_epoch, i_epoch)

        torch.save(
            {
                's_encoder': s_encoder.state_dict(),
                'q_encoder': q_encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, save_fname)

    if _config['record']:
        writer.close()
Example #8
class VaeGanModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        # Encoder
        self.encoder = Encoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.encoder.apply(weights_init)
        device = "cuda" if isinstance(self.hparams.gpus, int) else "cpu"
        # Decoder
        self.decoder = Decoder(ngf=self.hparams.ngf, z_dim=self.hparams.z_dim)
        self.decoder.apply(weights_init)
        # Discriminator
        self.discriminator = Discriminator()
        self.discriminator.apply(weights_init)

        # Losses
        self.criterionFeat = torch.nn.L1Loss()
        self.criterionGAN = GANLoss(gan_mode="lsgan")

        if self.hparams.use_vgg:
            self.criterion_perceptual_style = [Perceptual_Loss(device)]

    @staticmethod
    def reparameterize(mu, logvar, mode='train'):
        if mode == 'train':
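            # Reparameterization trick: z = mu + sigma * eps, with
            # sigma = exp(0.5 * logvar) and eps ~ N(0, I), keeping the
            # sample differentiable w.r.t. mu and logvar.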
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def discriminate(self, fake_image, real_image):
        input_concat_fake = torch.cat(
            (fake_image.detach(), real_image),
            dim=1)  # non sono sicuro che .detach() sia necessario in lightning
        input_concat_real = torch.cat((real_image, real_image), dim=1)

        return (self.discriminator(input_concat_fake),
                self.discriminator(input_concat_real))

    def training_step(self, batch, batch_idx, optimizer_idx):
        x, _ = batch

        # train VAE
        if optimizer_idx == 0:

            # encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # decode
            fake_image = self.decoder(z_repar)

            # reconstruction
            reconstruction_loss = self.criterionFeat(fake_image, x)
            kld_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) -
                                         log_var.exp())

            # Discriminate
            input_concat_fake = torch.cat((fake_image, x), dim=1)
            pred_fake = self.discriminator(input_concat_fake)

            # Losses
            loss_G_GAN = self.criterionGAN(pred_fake, True)
            if self.hparams.use_vgg:
                loss_G_perceptual = \
                    self.criterion_perceptual_style[0](fake_image, x)
            else:
                loss_G_perceptual = 0.0
            g_loss = (reconstruction_loss *
                      20) + kld_loss + loss_G_GAN + loss_G_perceptual

            # Results are collected in a TrainResult object
            result = pl.TrainResult(g_loss)
            result.log("rec_loss", reconstruction_loss * 10, prog_bar=True)
            result.log("loss_G_GAN", loss_G_GAN, prog_bar=True)
            result.log("kld_loss", kld_loss, prog_bar=True)
            result.log("loss_G_perceptual", loss_G_perceptual, prog_bar=True)

        # train Discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # Encode
            mu, log_var = self.encoder(x)
            z_repar = VaeGanModule.reparameterize(mu, log_var)

            # Decode
            fake_image = self.decoder(z_repar)

            # how well can it label as real?
            pred_fake, pred_real = self.discriminate(fake_image, x)

            # Fake loss
            d_loss_fake = self.criterionGAN(pred_fake, False)

            # Real Loss
            d_loss_real = self.criterionGAN(pred_real, True)

            # Total loss is average of prediction of fakes and reals
            loss_D = (d_loss_fake + d_loss_real) / 2

            # Results are collected in a TrainResult object
            result = pl.TrainResult(loss_D)
            result.log("loss_D_real", d_loss_real, prog_bar=True)
            result.log("loss_D_fake", d_loss_fake, prog_bar=True)

        return result

    def training_epoch_end(self, training_step_outputs):
        z_appr = torch.normal(mean=0,
                              std=1,
                              size=(16, self.hparams.z_dim),
                              device=training_step_outputs[0].minimize.device)

        # Generate images from latent vector
        sample_imgs = self.decoder(z_appr)
        grid = torchvision.utils.make_grid(sample_imgs,
                                           normalize=True,
                                           range=(-1, 1))

        # where to save the image
        path = os.path.join(self.hparams.generated_images_folder,
                            f"generated_images_{self.current_epoch}.png")
        torchvision.utils.save_image(sample_imgs,
                                     path,
                                     normalize=True,
                                     range=(-1, 1))

        # Log images in tensorboard
        self.logger.experiment.add_image('generated_images', grid,
                                         self.current_epoch)

        # Epoch level metrics
        epoch_loss = torch.mean(
            torch.stack([x['minimize'] for x in training_step_outputs]))
        results = pl.TrainResult()
        results.log("epoch_loss", epoch_loss, prog_bar=False)

        return results

    def validation_step(self, batch, batch_idx):
        x, _ = batch

        # Encode
        mu, log_var = self.encoder(x)
        z_repar = VaeGanModule.reparameterize(mu, log_var)

        # Decode
        recons = self.decoder(z_repar)
        reconstruction_loss = nn.functional.mse_loss(recons, x)

        # Results are collected in a EvalResult object
        result = pl.EvalResult(checkpoint_on=reconstruction_loss)
        return result

    test_step = validation_step

    def configure_optimizers(self):
        params_vae = concat_generators(self.encoder.parameters(),
                                       self.decoder.parameters())
        opt_vae = torch.optim.Adam(params_vae,
                                   lr=self.hparams.learning_rate_vae)

        parameters_discriminator = self.discriminator.parameters()
        opt_d = torch.optim.Adam(parameters_discriminator,
                                 lr=self.hparams.learning_rate_d)

        return [opt_vae, opt_d]

    @staticmethod
    def add_argparse_args(parser):

        parser.add_argument('--generated_images_folder',
                            required=False,
                            default="./output",
                            type=str)
        parser.add_argument('--ngf', type=int, default=128)
        parser.add_argument('--z_dim', type=int, default=128)
        parser.add_argument('--learning_rate_vae',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument('--learning_rate_d',
                            default=1e-03,
                            required=False,
                            type=float)
        parser.add_argument("--use_vgg", action="store_true", default=False)

        return parser
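
A hypothetical launch sketch for VaeGanModule; the --gpus flag, the trainer arguments, and the dataloader are assumptions for illustration, not part of the example:

import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=int, default=0)  # __init__ reads hparams.gpus
parser = VaeGanModule.add_argparse_args(parser)
hparams = parser.parse_args()

model = VaeGanModule(hparams)
trainer = pl.Trainer(gpus=hparams.gpus, max_epochs=50)
# trainer.fit(model, train_dataloader)  # supply a DataLoader of (image, label) batches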
Example #9
def main():
    a = argparse.ArgumentParser()
    a.add_argument("--debug", "-D", action="store_true")
    a.add_argument("--loss_only", "-L", action="store_true")
    args = a.parse_args()

    print("MODEL ID: {}".format(C.id))
    print("DEBUG MODE: {}".format(['OFF', 'ON'][args.debug]))

    if not args.debug:
        summary_writer = SummaryWriter(C.log_dpath)
    """ Load DataLoader """
    MSVD = _MSVD(C)
    vocab = MSVD.vocab
    train_data_loader = iter(cycle(MSVD.train_data_loader))
    val_data_loader = iter(cycle(MSVD.val_data_loader))

    print('n_vocabs: {} ({}), n_words: {} ({}). MIN_COUNT: {}'.format(
        vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words,
        vocab.n_words_untrimmed, C.min_count))
    """ Build Decoder """
    decoder = Decoder(model_name=C.decoder_model,
                      n_layers=C.decoder_n_layers,
                      encoder_size=C.encoder_output_size,
                      embedding_size=C.embedding_size,
                      embedding_scale=C.embedding_scale,
                      hidden_size=C.decoder_hidden_size,
                      attn_size=C.decoder_attn_size,
                      output_size=vocab.n_vocabs,
                      embedding_dropout=C.embedding_dropout,
                      dropout=C.decoder_dropout,
                      out_dropout=C.decoder_out_dropout)
    decoder = decoder.to(C.device)
    decoder_loss_func = nn.CrossEntropyLoss()
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=C.decoder_learning_rate,
                                   weight_decay=C.decoder_weight_decay,
                                   amsgrad=C.decoder_use_amsgrad)
    # Deprecated Variable replaced with a tensor created directly on the device
    decoder_lambda = torch.tensor(0.001, requires_grad=True, device=C.device)
    """ Build Reconstructor """
    if C.use_recon:
        if C.reconstructor_type == "local":
            reconstructor = LocalReconstructor(
                model_name=C.reconstructor_model,
                n_layers=C.reconstructor_n_layers,
                decoder_hidden_size=C.decoder_hidden_size,
                hidden_size=C.reconstructor_hidden_size,
                dropout=C.reconstructor_dropout,
                decoder_dropout=C.reconstructor_decoder_dropout,
                attn_size=C.reconstructor_attn_size)
        elif C.reconstructor_type == "global":
            reconstructor = GlobalReconstructor(
                model_name=C.reconstructor_model,
                n_layers=C.reconstructor_n_layers,
                decoder_hidden_size=C.decoder_hidden_size,
                hidden_size=C.reconstructor_hidden_size,
                dropout=C.reconstructor_dropout,
                decoder_dropout=C.reconstructor_decoder_dropout,
                caption_max_len=C.caption_max_len)
        else:
            raise NotImplementedError("Unknown reconstructor: {}".format(
                C.reconstructor_type))
        reconstructor = reconstructor.to(C.device)
        reconstructor_loss_func = nn.MSELoss()
        reconstructor_optimizer = optim.Adam(
            reconstructor.parameters(),
            lr=C.reconstructor_learning_rate,
            weight_decay=C.reconstructor_weight_decay,
            amsgrad=C.reconstructor_use_amsgrad)
        reconstructor_lambda = torch.tensor(0.01,
                                            requires_grad=True,
                                            device=C.device)
        loss_lambda = torch.tensor(1., requires_grad=True, device=C.device)
    """ Train """
    train_loss = 0
    if C.use_recon:
        train_dec_loss = 0
        train_rec_loss = 0
    for iteration, batch in enumerate(train_data_loader, 1):
        if C.use_recon:
            loss, decoder_loss, _, recon_loss = dec_rec_step(
                batch,
                decoder,
                decoder_loss_func,
                decoder_lambda,
                decoder_optimizer,
                reconstructor,
                reconstructor_loss_func,
                reconstructor_lambda,
                reconstructor_optimizer,
                loss_lambda,
                is_train=True)
            train_dec_loss += decoder_loss
            train_rec_loss += recon_loss
        else:
            loss, _ = dec_step(batch,
                               decoder,
                               decoder_loss_func,
                               decoder_lambda,
                               decoder_optimizer,
                               is_train=True)
        train_loss += loss
        """ Log Train Progress """
        if args.debug or iteration % C.log_every == 0:
            train_loss /= C.log_every
            if C.use_recon:
                train_dec_loss /= C.log_every
                train_rec_loss /= C.log_every

            if not args.debug:
                summary_writer.add_scalar(C.tx_train_loss, train_loss,
                                          iteration)
                summary_writer.add_scalar(C.tx_lambda_decoder,
                                          decoder_lambda.item(), iteration)
                if C.use_recon:
                    summary_writer.add_scalar(C.tx_train_loss_decoder,
                                              train_dec_loss, iteration)
                    summary_writer.add_scalar(C.tx_train_loss_reconstructor,
                                              train_rec_loss, iteration)
                    summary_writer.add_scalar(C.tx_lambda_reconstructor,
                                              reconstructor_lambda.item(),
                                              iteration)
                    summary_writer.add_scalar(C.tx_lambda, loss_lambda.item(),
                                              iteration)

            if C.use_recon:
                print(
                    "Iter {} / {} ({:.1f}%): loss {:.5f} (dec {:.5f} + rec {:.5f})"
                    .format(iteration, C.train_n_iteration,
                            iteration / C.train_n_iteration * 100, train_loss,
                            train_dec_loss, train_rec_loss))
            else:
                print("Iter {} / {} ({:.1f}%): loss {:.5f}".format(
                    iteration, C.train_n_iteration,
                    iteration / C.train_n_iteration * 100, train_loss))

            train_loss = 0
            if C.use_recon:
                train_dec_loss = 0
                train_rec_loss = 0
        """ Log Validation Progress """
        if args.debug or iteration % C.validate_every == 0:
            val_loss = 0
            val_dec_loss = 0
            val_rec_loss = 0
            gt_captions = []
            pd_captions = []
            for batch in val_data_loader:
                if C.use_recon:
                    loss, decoder_loss, decoder_output_indices, recon_loss = dec_rec_step(
                        batch,
                        decoder,
                        decoder_loss_func,
                        decoder_lambda,
                        decoder_optimizer,
                        reconstructor,
                        reconstructor_loss_func,
                        reconstructor_lambda,
                        reconstructor_optimizer,
                        loss_lambda,
                        is_train=False)
                    val_dec_loss += decoder_loss * C.batch_size
                    val_rec_loss += recon_loss * C.batch_size
                else:
                    loss, decoder_output_indices = dec_step(batch,
                                                            decoder,
                                                            decoder_loss_func,
                                                            decoder_lambda,
                                                            decoder_optimizer,
                                                            is_train=False)
                val_loss += loss * C.batch_size

                _, _, targets = batch
                gt_idxs = targets.cpu().numpy()
                pd_idxs = decoder_output_indices.cpu().numpy()
                gt_captions += convert_idxs_to_sentences(
                    gt_idxs, vocab.idx2word, vocab.word2idx['<EOS>'])
                pd_captions += convert_idxs_to_sentences(
                    pd_idxs, vocab.idx2word, vocab.word2idx['<EOS>'])

                if len(pd_captions) >= C.n_val:
                    assert len(gt_captions) == len(pd_captions)
                    gt_captions = gt_captions[:C.n_val]
                    pd_captions = pd_captions[:C.n_val]
                    break
            val_loss /= C.n_val
            if C.use_recon:
                val_dec_loss /= C.n_val
                val_rec_loss /= C.n_val

            if C.use_recon:
                print(
                    "[Validation] Iter {} / {} ({:.1f}%): loss {:.5f} (dec {:.5f} + rec {:5f})"
                    .format(iteration, C.train_n_iteration,
                            iteration / C.train_n_iteration * 100, val_loss,
                            val_dec_loss, val_rec_loss))
            else:
                print(
                    "[Validation] Iter {} / {} ({:.1f}%): loss {:.5f}".format(
                        iteration, C.train_n_iteration,
                        iteration / C.train_n_iteration * 100, val_loss))

            caption_pairs = list(zip(gt_captions, pd_captions))
            caption_pairs = sample_n(caption_pairs,
                                     min(C.n_val_logs, C.batch_size))
            caption_log = "\n\n".join([
                "[GT] {}  \n[PD] {}".format(gt, pd) for gt, pd in caption_pairs
            ])

            if not args.debug:
                summary_writer.add_scalar(C.tx_val_loss, val_loss, iteration)
                if C.use_recon:
                    summary_writer.add_scalar(C.tx_val_loss_decoder,
                                              val_dec_loss, iteration)
                    summary_writer.add_scalar(C.tx_val_loss_reconstructor,
                                              val_rec_loss, iteration)
                summary_writer.add_text(C.tx_predicted_captions, caption_log,
                                        iteration)
        """ Log Test Progress """
        if not args.loss_only and (args.debug
                                   or iteration % C.test_every == 0):
            pd_vid_caption_pairs = []
            score_data_loader = MSVD.score_data_loader
            print("[Test] Iter {} / {} ({:.1f}%)".format(
                iteration, C.train_n_iteration,
                iteration / C.train_n_iteration * 100))
            for search_method in C.search_methods:
                if isinstance(search_method, str):
                    method = search_method
                    search_method_id = search_method
                elif isinstance(search_method, tuple):
                    method = search_method[0]
                    search_method_id = "-".join(
                        str(s) for s in search_method)
                scores = evaluate(C, MSVD, score_data_loader, decoder,
                                  search_method)
                score_summary = " ".join([
                    "{}: {:.3f}".format(score, scores[score])
                    for score in C.scores
                ])
                print("\t{}: {}".format(search_method_id, score_summary))
                if not args.debug:
                    for score in C.scores:
                        summary_writer.add_scalar(
                            C.tx_score[search_method_id][score], scores[score],
                            iteration)
        """ Save checkpoint """
        if iteration % C.save_every == 0:
            if not os.path.exists(C.save_dpath):
                os.makedirs(C.save_dpath)
            fpath = os.path.join(C.save_dpath,
                                 "{}_checkpoint.tar".format(iteration))

            if C.use_recon:
                torch.save(
                    {
                        'iteration': iteration,
                        'dec': decoder.state_dict(),
                        'rec': reconstructor.state_dict(),
                        'dec_opt': decoder_optimizer.state_dict(),
                        'rec_opt': reconstructor_optimizer.state_dict(),
                        'loss': loss,
                        'config': C,
                    }, fpath)
            else:
                torch.save(
                    {
                        'iteration': iteration,
                        'dec': decoder.state_dict(),
                        'dec_opt': decoder_optimizer.state_dict(),
                        'loss': loss,
                        'config': C,
                    }, fpath)

        if iteration == C.train_n_iteration:
            break
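
# A minimal sketch of how one of the checkpoints saved above could be
# restored before resuming training; the keys mirror the torch.save()
# calls in the loop, and it assumes the same C config plus
# already-constructed models and optimizers.
def load_checkpoint_sketch(fpath, decoder, decoder_optimizer,
                           reconstructor=None, reconstructor_optimizer=None):
    checkpoint = torch.load(fpath, map_location=C.device)
    decoder.load_state_dict(checkpoint['dec'])
    decoder_optimizer.load_state_dict(checkpoint['dec_opt'])
    if 'rec' in checkpoint:  # present only when C.use_recon was enabled
        reconstructor.load_state_dict(checkpoint['rec'])
        reconstructor_optimizer.load_state_dict(checkpoint['rec_opt'])
    return checkpoint['iteration'] + 1  # iteration to resume from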
Ejemplo n.º 10
0
class Model(pl.LightningModule):

    def __init__(self, cfg_network: DictConfig, cfg_tester: DictConfig):
        super().__init__()
        self.cfg_network = cfg_network
        self.cfg_tester = cfg_tester

        # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use
        torch.backends.cudnn.benchmark = True

        # Set up networks
        self.encoder = Encoder(cfg_network)
        self.decoder = Decoder(cfg_network)
        self.refiner = Refiner(cfg_network)
        self.merger = Merger(cfg_network)
        
        # Initialize weights of networks
        self.encoder.apply(utils.network_utils.init_weights)
        self.decoder.apply(utils.network_utils.init_weights)
        self.refiner.apply(utils.network_utils.init_weights)
        self.merger.apply(utils.network_utils.init_weights)
        
        self.bce_loss = nn.BCELoss()

    def configure_optimizers(self):
        params = self.cfg_network.optimization
        # Set up solver
        if params.policy == 'adam':
            encoder_solver = optim.Adam(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                        lr=params.encoder_lr,
                                        betas=params.betas)
            decoder_solver = optim.Adam(self.decoder.parameters(),
                                        lr=params.decoder_lr,
                                        betas=params.betas)
            refiner_solver = optim.Adam(self.refiner.parameters(),
                                        lr=params.refiner_lr,
                                        betas=params.betas)
            merger_solver = optim.Adam(self.merger.parameters(),
                                       lr=params.merger_lr,
                                       betas=params.betas)
        elif params.policy == 'sgd':
            encoder_solver = optim.SGD(filter(lambda p: p.requires_grad, self.encoder.parameters()),
                                       lr=params.encoder_lr,
                                       momentum=params.momentum)
            decoder_solver = optim.SGD(self.decoder.parameters(),
                                       lr=params.decoder_lr,
                                       momentum=params.momentum)
            refiner_solver = optim.SGD(self.refiner.parameters(),
                                       lr=params.refiner_lr,
                                       momentum=params.momentum)
            merger_solver = optim.SGD(self.merger.parameters(),
                                      lr=params.merger_lr,
                                      momentum=params.momentum)
        else:
            raise ValueError('[FATAL] %s Unknown optimizer policy %s.' % (dt.now(), params.policy))
            
        # Set up learning rate schedulers to decay learning rates dynamically
        encoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                              milestones=params.encoder_lr_milestones,
                                                              gamma=params.gamma)
        decoder_lr_scheduler = optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                              milestones=params.decoder_lr_milestones,
                                                              gamma=params.gamma)
        refiner_lr_scheduler = optim.lr_scheduler.MultiStepLR(refiner_solver,
                                                              milestones=params.refiner_lr_milestones,
                                                              gamma=params.gamma)
        merger_lr_scheduler = optim.lr_scheduler.MultiStepLR(merger_solver,
                                                             milestones=params.merger_lr_milestones,
                                                             gamma=params.gamma)
        
        return [encoder_solver, decoder_solver, refiner_solver, merger_solver], \
               [encoder_lr_scheduler, decoder_lr_scheduler, refiner_lr_scheduler, merger_lr_scheduler]
    
    def _fwd(self, batch):
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch

        image_features = self.encoder(rendering_images)
        raw_features, generated_volumes = self.decoder(image_features)

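        # Fuse the per-view volumes with the merger once it is enabled;
        # until then, fall back to a plain mean over the view dimension.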
        if self.cfg_network.use_merger and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_merger:
            generated_volumes = self.merger(raw_features, generated_volumes)
        else:
            generated_volumes = torch.mean(generated_volumes, dim=1)
        encoder_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10  # fixed x10 weighting of the voxel BCE
        
        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            generated_volumes = self.refiner(generated_volumes)
            refiner_loss = self.bce_loss(generated_volumes, ground_truth_volumes) * 10
        else:
            refiner_loss = encoder_loss
        
        return generated_volumes, encoder_loss, refiner_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
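        # NOTE: this step drives optimization by hand (self.optimizers(),
        # manual_backward, opt.step()/zero_grad), which assumes automatic
        # optimization is disabled (e.g. self.automatic_optimization = False,
        # or the equivalent Trainer flag in older PyTorch Lightning releases).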
        (opt_enc, opt_dec, opt_ref, opt_merg) = self.optimizers()
        
        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)
        
        self.log('loss/EncoderDecoder', encoder_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('loss/Refiner', refiner_loss, 
                 prog_bar=True, logger=True, on_step=True, on_epoch=True)

        if self.cfg_network.use_refiner and self.current_epoch >= self.cfg_network.optimization.epoch_start_use_refiner:
            self.manual_backward(encoder_loss, opt_enc, retain_graph=True)
            self.manual_backward(refiner_loss, opt_ref)
        else:
            self.manual_backward(encoder_loss, opt_enc)
            
        for opt in self.optimizers():
            opt.step()
            opt.zero_grad()

    def training_epoch_end(self, outputs) -> None:
        # Update Rendering Views
        if self.cfg_network.update_n_views_rendering:
            n_views_rendering = self.trainer.datamodule.update_n_views_rendering()
            # current_epoch is 0-based; "+ 2" labels the upcoming epoch in 1-based numbering
            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %
                  (dt.now(), self.current_epoch + 2, self.trainer.max_epochs, n_views_rendering))

    def _eval_step(self, batch, batch_idx):
        # SUPPORTS ONLY BATCH_SIZE=1
        taxonomy_names, sample_names, rendering_images, ground_truth_volumes = batch
        taxonomy_id = taxonomy_names[0]
        sample_name = sample_names[0]

        generated_volumes, encoder_loss, refiner_loss = self._fwd(batch)

        self.log('val_loss/EncoderDecoder', encoder_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        self.log('val_loss/Refiner', refiner_loss, prog_bar=True,
                 logger=True, on_step=True, on_epoch=True)

        # IoU per sample
        sample_iou = []
        for th in self.cfg_tester.voxel_thresh:
            _volume = torch.ge(generated_volumes, th).float()
            intersection = torch.sum(_volume.mul(ground_truth_volumes)).float()
            union = torch.sum(
                torch.ge(_volume.add(ground_truth_volumes), 1)).float()
            sample_iou.append((intersection / union).item())

        # Print sample loss and IoU
        n_samples = -1  # total test-set size is not known inside a single step; placeholder
        print('\n[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %
              (dt.now(), batch_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),
               refiner_loss.item(), ['%.4f' % si for si in sample_iou]))

        return {
            'taxonomy_id': taxonomy_id,
            'sample_name': sample_name,
            'sample_iou': sample_iou
        }
        
    def _eval_epoch_end(self, outputs):
        # Load taxonomies of dataset
        taxonomies = []
        taxonomy_path = self.trainer.datamodule.get_test_taxonomy_file_path()
        with open(taxonomy_path, encoding='utf-8') as file:
            taxonomies = json.loads(file.read())
        taxonomies = {t['taxonomy_id']: t for t in taxonomies}

        test_iou = {}
        for output in outputs:
            taxonomy_id, sample_name, sample_iou = output[
                'taxonomy_id'], output['sample_name'], output['sample_iou']
            if taxonomy_id not in test_iou:
                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}
            test_iou[taxonomy_id]['n_samples'] += 1
            test_iou[taxonomy_id]['iou'].append(sample_iou)

        mean_iou = []
        for taxonomy_id in test_iou:
            test_iou[taxonomy_id]['iou'] = torch.mean(
                torch.tensor(test_iou[taxonomy_id]['iou']), dim=0)
            mean_iou.append(test_iou[taxonomy_id]['iou']
                            * test_iou[taxonomy_id]['n_samples'])
        n_samples = len(outputs)
        mean_iou = torch.stack(mean_iou)
        mean_iou = torch.sum(mean_iou, dim=0) / n_samples

        # Print header
        print('============================ TEST RESULTS ============================')
        print('Taxonomy', end='\t')
        print('#Sample', end='\t')
        print(' Baseline', end='\t')
        for th in self.cfg_tester.voxel_thresh:
            print('t=%.2f' % th, end='\t')
        print()
        # Print body
        for taxonomy_id in test_iou:
            print('%s' % taxonomies[taxonomy_id]
                  ['taxonomy_name'].ljust(8), end='\t')
            print('%d' % test_iou[taxonomy_id]['n_samples'], end='\t')
            if 'baseline' in taxonomies[taxonomy_id]:
                n_views_rendering = self.trainer.datamodule.get_n_views_rendering()
                print('%.4f' % taxonomies[taxonomy_id]['baseline']
                      ['%d-view' % n_views_rendering], end='\t\t')
            else:
                print('N/A', end='\t\t')

            for ti in test_iou[taxonomy_id]['iou']:
                print('%.4f' % ti, end='\t')
            print()
        # Print mean IoU for each threshold
        print('Overall ', end='\t\t\t\t')
        for mi in mean_iou:
            print('%.4f' % mi, end='\t')
        print('\n')

        max_iou = torch.max(mean_iou)
        self.log('Refiner/IoU', max_iou, prog_bar=True, on_epoch=True)
    
    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def validation_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)
        
    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, batch_idx)
        
    def test_epoch_end(self, outputs):
        self._eval_epoch_end(outputs)

    def get_progress_bar_dict(self):
        # don't show the loss as it's None
        items = super().get_progress_bar_dict()
        items.pop("loss", None)
        return items
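
# A minimal usage sketch for this LightningModule, assuming an OmegaConf
# config whose network/tester groups match the keys read above, and a
# hypothetical ShapeNetDataModule yielding the (taxonomy_names,
# sample_names, rendering_images, ground_truth_volumes) batches that
# _fwd expects.
import pytorch_lightning as pl
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")           # hypothetical config file
model = Model(cfg.network, cfg.tester)
dm = ShapeNetDataModule(cfg)                  # hypothetical datamodule
trainer = pl.Trainer(max_epochs=250, gpus=1)  # older PL API, matching the hooks used above
trainer.fit(model, datamodule=dm)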
Ejemplo n.º 11
0
class Net(torch.nn.Module):
    def __init__(self, config):
        super(Net, self).__init__()
        self.encoder_type = config["encoder_type"]
        self.decoder = Decoder(layer_channels[self.encoder_type])
        self.encoder = get_encoder(config)
        self.get_skip_layer()

    def forward(self, x):
        sk_1, sk_2, sk_3, sk_4, x = self.encoder(x)
        sk_1 = self.skip_1(sk_1)
        sk_2 = self.skip_2(sk_2)
        sk_3 = self.skip_3(sk_3)
        sk_4 = self.skip_4(sk_4)
        return self.decoder(x, sk_1, sk_2, sk_3, sk_4)

    def load_model(self, prep_path, save_path):
        if os.path.exists(save_path):
            print("load from saved model:" + save_path + '...')
            checkpoint = torch.load(save_path)
            self.load_state_dict(checkpoint['model_state_dict'])
            ech = checkpoint['epoch']
            self.eval()
            print("load complete")
            return ech
        else:
            if self.load_pre_from_local():
                print("load pre-parameters:" + prep_path + '...')
                prep = torch.load(prep_path)
                model_dict = self.encoder.state_dict()
                prep = self.parameter_rename(prep, model_dict)
                pre_trained_dict = {
                    k: v
                    for k, v in prep.items() if k in model_dict
                }
                model_dict.update(pre_trained_dict)
                self.encoder.load_state_dict(model_dict)
                print("load complete")
            else:
                print("use pretrained model", self.encoder_type,
                      " from torchvision")
            return 0

    def save_model(self, ech, save_path):
        torch.save({
            'epoch': ech,
            'model_state_dict': self.state_dict(),
        }, save_path)

    def parameter_rename(self, org_dict, target_dict):
        if self.encoder_type == "vgg19" or self.encoder_type == "vgg19_bn":
            # Map the torchvision VGG parameter names onto this encoder's
            # names, skipping the BatchNorm "num_batches_tracked" buffers
            org_list = []
            target_list = []
            for k, _ in target_dict.items():
                if k.find("batches") < 0:
                    target_list.append(k)
            for k, _ in org_dict.items():
                if k.find("batches") < 0:
                    org_list.append(k)
            for i in range(len(target_list)):
                org_dict[target_list[i]] = org_dict.pop(org_list[i])
            return org_dict
        elif self.encoder_type == "resnet152":
            # resnet152 weights come straight from torchvision (see
            # load_pre_from_local), so no renaming is needed
            return target_dict

    def get_skip_layer(self):
        channel = layer_channels[self.encoder_type]
        if channel is not None:
            self.skip_1 = torch.nn.Conv2d(channel[3], channel[3], 1, 1, 0)
            self.skip_2 = torch.nn.Conv2d(channel[2], channel[2], 1, 1, 0)
            self.skip_3 = torch.nn.Conv2d(channel[1], channel[1], 1, 1, 0)
            self.skip_4 = torch.nn.Conv2d(channel[0], channel[0], 1, 1, 0)
        else:
            raise RuntimeError("invalid encoder type")

    def load_pre_from_local(self):
        return self.encoder_type == "vgg19" or self.encoder_type == "vgg19_bn"

    def optimizer_by_layer(self, encoder_lr, decoder_lr):
        params = [{
            "params": self.encoder.parameters(),
            "lr": encoder_lr
        }, {
            "params": self.decoder.parameters(),
            "lr": decoder_lr
        }, {
            "params": self.skip_1.parameters(),
            "lr": encoder_lr
        }, {
            "params": self.skip_2.parameters(),
            "lr": encoder_lr
        }, {
            "params": self.skip_3.parameters(),
            "lr": encoder_lr
        }, {
            "params": self.skip_4.parameters(),
            "lr": encoder_lr
        }]
        return optim.Adam(params=params, lr=encoder_lr)
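
# A minimal usage sketch for Net with hypothetical paths and learning
# rates; it assumes layer_channels and get_encoder support the chosen
# encoder type.
config = {"encoder_type": "vgg19_bn"}
net = Net(config)
start_epoch = net.load_model("vgg19_bn.pth", "checkpoints/net.pth")  # hypothetical paths
optimizer = net.optimizer_by_layer(encoder_lr=1e-4, decoder_lr=1e-3)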
Ejemplo n.º 12
0
                                                    num_workers=4)

    VOCAB_SIZE = train_dataset.vocab.num_words
    SEQ_LEN = train_dataset.vocab.max_sentence_len

    encoder = Encoder(args.ENCODER_OUTPUT_SIZE).to(device)
    decoder = Decoder(embed_size=args.EMBED_SIZE,
                      hidden_size=args.HIDDEN_SIZE,
                      attention_size=args.ATTENTION_SIZE,
                      vocab_size=VOCAB_SIZE,
                      encoder_size=2048,
                      device=device,
                      seq_len=SEQ_LEN + 2).to(device)

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=args.LR)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.LR)

    # No ignore_index is set: padding tokens are stripped before the loss is computed
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    validation_losses = []

    for epoch in range(args.NUM_EPOCHS):
        train_loss = 0
        validation_loss = 0
        encoder.train()
        decoder.train()
        for idx, (img, caption_5,
                  caption_lengths_5) in enumerate(train_loader):
            origin_img = img