Example #1
def init_model():
    torch.backends.cudnn.benchmark = True
    get_logger().info('Initializing classification model...')
    # model = HighResNet(dropout_rate=config.DROPOUT_RATE).to(config.DEVICE)
    model = HighSEResNeXt(dropout_rate=config.DROPOUT_RATE).to(config.DEVICE)
    # model = HighSEResNeXt2(dropout_rate=config.DROPOUT_RATE).to(config.DEVICE)
    # model = HighCbamResNet(dropout_rate=config.DROPOUT_RATE).to(config.DEVICE)

    # criterion = torch.nn.BCEWithLogitsLoss()
    label_weight = torch.tensor([1, 1, 1, 1, 1, 2]).to(
        config.DEVICE, dtype=torch.float)
    criterion = WeightedBCE(label_weight)
    # criterion = FocalBCELoss(
    #    bce_weight=0.7, label_weight=label_weight, gamma=2)
    # criterion = torch.nn.BCEWithLogitsLoss()
    '''
    optimizer = torch.optim.SGD([{'params': model.parameters()}],
                                lr=config.SGD_LR,
                                momentum=config.MOMENTUM,
                                weight_decay=config.WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.ITER_PER_CYCLE, config.MIN_LR)
    '''
    optimizer = optim.Adam([{'params': model.parameters()}], lr=config.ADAM_LR)
    mile_stones = [1, 2]
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, mile_stones, gamma=0.5, last_epoch=-1)

    amp.register_float_function(torch, 'sigmoid')
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    return model, criterion, optimizer, scheduler
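For context, a minimal sketch of how the objects returned by init_model() are typically consumed in an apex O1 training step. The loop below is illustrative only and train_loader is a placeholder name, not part of the original example; the key point is that the loss is backpropagated through amp.scale_loss so gradient scaling matches the amp.initialize call above.

# Illustrative usage sketch (train_loader is a placeholder, not from the original code):
model, criterion, optimizer, scheduler = init_model()
model.train()
for images, targets in train_loader:
    images = images.to(config.DEVICE)
    targets = targets.to(config.DEVICE, dtype=torch.float)

    optimizer.zero_grad()
    loss = criterion(model(images), targets)

    # Backpropagate through apex's loss scaler instead of calling loss.backward() directly.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
scheduler.step()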
Example #2
def evaluate(gpu: int, config: dict, shared_dict, barrier, eval_ds, backbone):
    # --- Setup DistributedDataParallel --- #
    rank = config["nr"] * config["gpus"] + gpu
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://',
                                         world_size=config["world_size"],
                                         rank=rank)

    if gpu == 0:
        print("# --- Start evaluating --- #")

    # Choose device
    torch.cuda.set_device(gpu)

    # --- Online transform performed on the device (GPU):
    eval_online_cuda_transform = data_transforms.get_eval_online_cuda_transform(
        config)

    if "samples" in config:
        rng_samples = random.Random(0)
        eval_ds = torch.utils.data.Subset(
            eval_ds, rng_samples.sample(range(len(eval_ds)),
                                        config["samples"]))
        # eval_ds = torch.utils.data.Subset(eval_ds, range(config["samples"]))

    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        eval_ds, num_replicas=config["world_size"], rank=rank)

    eval_ds = torch.utils.data.DataLoader(
        eval_ds,
        batch_size=config["optim_params"]["eval_batch_size"],
        pin_memory=True,
        sampler=eval_sampler,
        num_workers=config["num_workers"])

    model = FrameFieldModel(config,
                            backbone=backbone,
                            eval_transform=eval_online_cuda_transform)
    model.cuda(gpu)

    if config["use_amp"] and APEX_AVAILABLE:
        amp.register_float_function(torch, 'sigmoid')
        model = amp.initialize(model, opt_level="O1")
    elif config["use_amp"] and not APEX_AVAILABLE and gpu == 0:
        print_utils.print_warning(
            "WARNING: Cannot use amp because the apex library is not available!"
        )

    # Wrap the model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    evaluator = Evaluator(gpu,
                          config,
                          shared_dict,
                          barrier,
                          model,
                          run_dirpath=config["eval_params"]["run_dirpath"])
    split_name = config["fold"][0]
    evaluator.evaluate(split_name, eval_ds)
Example #3
def get_model_optimizer_and_scheduler(cfg, seed=0):
    r"""Return the networks, the optimizers, and the schedulers. We will
    first set the random seed to a fixed value so that each GPU copy will be
    initialized to have the same network weights. We will then use different
    random seeds for different GPUs. After this we will wrap the generator
    with a moving average model if applicable. It is followed by getting the
    optimizers, amp initialization, and distributed data parallel wrapping.

    Args:
        cfg (obj): Global configuration.
        seed (int): Random seed.

    Returns:
        (dict):
          - net_G (obj): Generator network object.
          - net_D (obj): Discriminator network object.
          - opt_G (obj): Generator optimizer object.
          - opt_D (obj): Discriminator optimizer object.
          - sch_G (obj): Generator optimizer scheduler object.
          - sch_D (obj): Discriminator optimizer scheduler object.
    """
    # We first set the random seed to the same value on every process so that
    # each copy of the network is initialized in exactly the same way and has
    # the same weights and other parameters (the seed is not offset by rank here).
    set_random_seed(seed, by_rank=False)
    # Construct networks
    lib_G = importlib.import_module(cfg.gen.type)
    lib_D = importlib.import_module(cfg.dis.type)
    net_G = lib_G.Generator(cfg.gen, cfg.data).to('cuda')
    net_D = lib_D.Discriminator(cfg.dis, cfg.data).to('cuda')
    print('Initialize net_G and net_D weights using '
          'type: {} gain: {}'.format(cfg.trainer.init.type,
                                     cfg.trainer.init.gain))
    init_bias = getattr(cfg.trainer.init, 'bias', None)
    net_G.apply(
        weights_init(cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    net_D.apply(
        weights_init(cfg.trainer.init.type, cfg.trainer.init.gain, init_bias))
    # Different GPU copies of the same model will receive noises
    # initialized with different random seeds (if applicable) thanks to the
    # set_random_seed command (GPU #K has random seed = args.seed + K).
    set_random_seed(seed, by_rank=True)
    print('net_G parameter count: {:,}'.format(_calculate_model_size(net_G)))
    print('net_D parameter count: {:,}'.format(_calculate_model_size(net_D)))

    # Optimizer
    opt_G = get_optimizer(cfg.gen_opt, net_G)
    opt_D = get_optimizer(cfg.dis_opt, net_D)

    net_G, net_D, opt_G, opt_D = \
        wrap_model_and_optimizer(cfg, net_G, net_D, opt_G, opt_D)

    # Scheduler
    sch_G = get_scheduler(cfg.gen_opt, opt_G)
    sch_D = get_scheduler(cfg.dis_opt, opt_D)

    return net_G, net_D, opt_G, opt_D, sch_G, sch_D
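The docstring above relies on set_random_seed(seed, by_rank) to give every GPU copy identical initial weights and then per-rank randomness. A rough sketch of what such a helper usually looks like; the actual implementation in the source project may differ:

import random

import numpy as np
import torch


def set_random_seed(seed, by_rank=False):
    """Seed python, numpy and torch; offset the seed by the process rank if requested."""
    if by_rank and torch.distributed.is_available() and torch.distributed.is_initialized():
        seed += torch.distributed.get_rank()  # GPU #K effectively uses seed + K
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)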
Example #4
def get_model_optimizer(args, CFG):
    model = I3D_SGA_STD(args.dropout_rate, args.expand_k,
                        freeze_backbone=not args.train_backbone,
                        freeze_blocks=args.freeze_blocks,
                        freeze_bn=not args.train_bn,
                        pretrained_backbone=args.pretrained_backbone,
                        pretrained_path=CFG.I3D_MODEL_PATH,
                        freeze_bn_statics=args.freeze_bn_sta).cuda()

    optimizer = get_optimizer(args, model)

    opt_level = 'O1'
    amp.init(allow_banned=True)
    amp.register_float_function(torch, 'softmax')
    amp.register_float_function(torch, 'sigmoid')
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=opt_level,
                                      keep_batchnorm_fp32=None)
    model = BalancedDataParallel(int(args.batch_size * 2 * args.clip_num / len(args.gpus) * args.gpu0sz),
                                 model, dim=0, device_ids=args.gpus)

    return model, optimizer
Example #5
    def __init__(self, args, num_labels, emb_freeze=False, class_weight=None):
        super().__init__()
        self.num_labels = num_labels
        self.emb_freeze = emb_freeze
        self.class_weight = None
        if class_weight is not None:
            self.class_weight = torch.tensor(class_weight)

        self.char_embedding = nn.Embedding(args.max_chars + 2, args.dim)

        self.covns = nn.ModuleList([
            nn.Conv2d(in_channels=1,
                      out_channels=args.filter_num,
                      kernel_size=(k, self.char_embedding.embedding_dim),
                      padding=0) for k in args.filters
        ])

        self.dropout = nn.Dropout(args.dropout_ratio)
        self.tanh = nn.Tanh()

        filter_dim = args.filter_num * len(args.filters)
        self.linear = nn.Linear(filter_dim, num_labels)

        if args.fp16:
            # Under apex amp, keep torch.sigmoid in fp32; register_float_function
            # returns None, so do not assign its result to self.sigmoid.
            amp.register_float_function(torch, 'sigmoid')
            self.sigmoid = torch.sigmoid
Example #6
    def __init__(self, args):
        self.args = args
        self.rank = args.local_rank
        self.main_rank = 0

        # fp16 handle
        amp.register_float_function(torch, 'multinomial')

        if args.local_rank != -1:
            torch.cuda.set_device(args.local_rank)
            device = torch.device("cuda", args.local_rank)
            torch.distributed.init_process_group(backend='nccl')
            self.args.device = device
            # get the world size
            self.args.n_gpu = torch.distributed.get_world_size()
        else:
            # no distributed training
            # set default cuda device
            self.args.device = torch.device("cuda")
            self.args.n_gpu = 1
Example #7
def build_training_mode(config, model, optimizer):
    '''
    Choose the model training mode (nn.DataParallel or nn.parallel.DistributedDataParallel)
    and whether to use apex or not.
    '''
    if config.distributed:
        if config.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
        if config.apex:
            amp.register_float_function(torch, 'sigmoid')
            amp.register_float_function(torch, 'softmax')
            amp.register_float_function(torchvision.ops, 'deform_conv2d')
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
            model = apex.parallel.DistributedDataParallel(model,
                                                          delay_allreduce=True)
            if config.sync_bn:
                model = apex.parallel.convert_syncbn_model(model).cuda()
        else:
            local_rank = torch.distributed.get_rank()
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)
    else:
        if config.apex:
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

        model = nn.DataParallel(model)

    return model
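A hedged sketch of a training step that stays compatible with every mode build_training_mode() can return: when config.apex is enabled the loss must go through amp.scale_loss, otherwise a plain backward pass is used. The function and argument names below are placeholders rather than the project's actual API.

def train_one_step(config, model, criterion, optimizer, images, targets):
    # Placeholder step showing only the apex / non-apex split.
    optimizer.zero_grad()
    loss = criterion(model(images.cuda()), targets.cuda())
    if config.apex:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return loss.item()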
Example #8
    def apex_cuda_setting(self, cfg):
        if cfg.APEX.IF_ON:
            logger.info("Using apex")
            try:
                import apex
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
            #
            # if cfg.APEX.IF_SYNC_BN:
            #     logger.info("Using apex synced BN")
            #     self.module = apex.parallel.convert_syncbn_model(self.module)
        if self.device == 'cuda':
            self.model = self.model.cuda()
            if cfg.APEX.IF_ON:
                from apex import amp
                amp.register_float_function(torch, 'sigmoid')
                self.model, self.optimizer = amp.initialize(self.model,
                                                            self.optimizer,
                                                            opt_level=cfg.APEX.OPT_LEVEL,
                                                            keep_batchnorm_fp32=None if cfg.APEX.OPT_LEVEL == 'O1' else True,
                                                            loss_scale=cfg.APEX.LOSS_SCALE[0])
Example #9
def main():
    cfg = Config()
    args = parse()

    print('Net Initializing')
    net = CSP(cfg).cuda()

    optimizer = optim.Adam(net.parameters(), lr=cfg.init_lr)
    amp.register_float_function(torch, 'sigmoid')
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')
    net = nn.DataParallel(net)

    checkpoint = torch.load(args.val_path)
    net.module.load_state_dict(checkpoint['model'])

    # dataset
    print('Dataset...')

    if cfg.val:
        testdataset = CityPersons(path=cfg.root_path, type='val', config=cfg)
        testloader = DataLoader(testdataset, batch_size=1, num_workers=4)

    MRs = val(testloader, net, cfg, args)
Example #10
def main_driver(train_tuple, val_tuple, test_tuple, tokenizer):
    pretrained_config = AutoConfig.from_pretrained(PRETRAINED_MODEL,
                                                   output_hidden_states=True)
    pretrained_base = AutoModel.from_pretrained(
        PRETRAINED_MODEL, config=pretrained_config).cuda()
    classifier = ClassifierHead(pretrained_base).cuda()
    loss_fn = torch.nn.BCELoss()
    opt = torch.optim.Adam(classifier.parameters(), lr=LR)

    amp.register_float_function(torch, 'sigmoid')
    classifier, opt = amp.initialize(classifier,
                                     opt,
                                     opt_level='O1',
                                     verbosity=0)
    list_auc = []

    current_tuple = train_tuple
    for curr_epoch in range(NUM_EPOCHS):
        # After half epochs, switch to training against validation set
        if curr_epoch == NUM_EPOCHS // 2 and len(val_tuple[-1]) > 0:
            current_tuple = val_tuple
        train(classifier, current_tuple, loss_fn, opt, curr_epoch)

        # Score against the validation set
        if len(val_tuple[-1]) > 0:
            epoch_raw_auc = predict_evaluate(classifier,
                                             val_tuple,
                                             curr_epoch,
                                             score=True)
            print('Epoch {} - Val AUC: {:.4f}'.format(curr_epoch,
                                                      epoch_raw_auc))
            list_auc.append(epoch_raw_auc)

        predict_evaluate(classifier, test_tuple, curr_epoch)

    with np.printoptions(precision=4, suppress=True):
        print(np.array(list_auc))
Example #11
    def __init__(self, bottleneck_setting=Mobilefacenet_bottleneck_setting):
        super(MobileFacenet, self).__init__()

        self.conv3 = ConvBlock(3, 64, 3, 2, 1)

        self.dw_conv3 = ConvBlock(64, 64, 3, 1, 1, dw=True)

        self.in_channels = 64
        bottleneck = Bottleneck
        self.bottlenecks = self._make_layer(bottleneck, bottleneck_setting)

        self.conv1 = ConvBlock(128, 512, 1, 1, 0)

        self.linear_GDConv7 = ConvBlock(512,
                                        512,
                                        7,
                                        1,
                                        0,
                                        dw=True,
                                        linear=True)

        self.linear_conv1 = ConvBlock(512, 128, 1, 1, 0, linear=True)

        # parameter init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # kaiming_normal
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # prevent overflow errors
        if args.use_amp:
            amp.register_float_function(torch, 'sigmoid')
            amp.register_float_function(torch, 'softmax')
Example #12
def main():
    # ---------------------------------------------------------
    # Configurations
    # ---------------------------------------------------------

    heavy_augmentation = True  # False to use author's default implementation
    gan_training = False
    mixup_augmentation = False
    fullsize_training = False
    multiscale_training = False
    multi_gpu = True
    mixed_precision_training = True

    model_name = "u2net"  # "u2net", "u2netp", "u2net_heavy"
    se_type = None  # "csse", "sse", "cse", None; None to use author's default implementation
    # checkpoint = "saved_models/u2net/u2net.pth"
    checkpoint = None
    checkpoint_netD = None

    w_adv = 0.2
    w_vgg = 0.2

    train_dirs = [
        "../datasets/sky_segmentation_dataset/datasets/cvprw2020_sky_seg/train/"
    ]
    train_dirs_file_limit = [
        None,
    ]

    image_ext = '.jpg'
    label_ext = '.png'
    dataset_name = "cvprw2020_sky_seg"

    lr = 0.0003
    epoch_num = 500
    batch_size_train = 48
    # batch_size_val = 1
    workers = 16
    save_frq = 1000  # save the model every 1000 iterations

    save_debug_samples = False
    debug_samples_dir = "./debug/"

    # ---------------------------------------------------------

    model_dir = './saved_models/' + model_name + '/'
    os.makedirs(model_dir, exist_ok=True)

    writer = SummaryWriter()

    if fullsize_training:
        batch_size_train = 1
        multiscale_training = False

    # ---------------------------------------------------------
    # 1. Construct data input pipeline
    # ---------------------------------------------------------

    # Get dataset name
    dataset_name = dataset_name.replace(" ", "_")

    # Get training data
    assert len(train_dirs) == len(train_dirs_file_limit), \
        "Different train dirs and train dirs file limit length!"

    tra_img_name_list = []
    tra_lbl_name_list = []
    for d, flimit in zip(train_dirs, train_dirs_file_limit):
        img_files = glob.glob(d + '**/*' + image_ext, recursive=True)
        if flimit:
            img_files = np.random.choice(img_files, size=flimit, replace=False)

        print(f"directory: {d}, files: {len(img_files)}")

        for img_path in img_files:
            lbl_path = img_path.replace("/image/", "/alpha/") \
                .replace(image_ext, label_ext)

            if os.path.exists(img_path) and os.path.exists(lbl_path):
                assert os.path.splitext(
                    os.path.basename(img_path))[0] == os.path.splitext(
                        os.path.basename(lbl_path))[0], "Wrong filename."

                tra_img_name_list.append(img_path)
                tra_lbl_name_list.append(lbl_path)
            else:
                print(
                    f"Warning, dropping sample {img_path} because label file {lbl_path} not found!"
                )

    tra_img_name_list, tra_lbl_name_list = shuffle(tra_img_name_list,
                                                   tra_lbl_name_list)

    train_num = len(tra_img_name_list)
    # val_num = 0  # unused
    print(f"dataset name        : {dataset_name}")
    print(f"training samples    : {train_num}")

    # Construct data input pipeline
    if heavy_augmentation:
        transform = AlbuSampleTransformer(
            get_heavy_transform(
                fullsize_training=fullsize_training,
                transform_size=not (fullsize_training or multiscale_training)))
    else:
        transform = transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
        ])

    # Create dataset and dataloader
    dataset_kwargs = dict(img_name_list=tra_img_name_list,
                          lbl_name_list=tra_lbl_name_list,
                          transform=transforms.Compose([
                              transform,
                          ] + ([
                              SaveDebugSamples(out_dir=debug_samples_dir),
                          ] if save_debug_samples else []) + ([
                              ToTensorLab(flag=0),
                          ] if not multiscale_training else [])))
    if mixup_augmentation:
        _dataset_cls = MixupAugSalObjDataset
    else:
        _dataset_cls = SalObjDataset

    salobj_dataset = _dataset_cls(**dataset_kwargs)
    salobj_dataloader = DataLoader(
        salobj_dataset,
        batch_size=batch_size_train,
        collate_fn=multi_scale_collater if multiscale_training else None,
        shuffle=True,
        pin_memory=True,
        num_workers=workers)

    # ---------------------------------------------------------
    # 2. Load model
    # ---------------------------------------------------------

    # Instantiate model
    if model_name == "u2net":
        net = U2NET(3, 1, se_type=se_type)
    elif model_name == "u2netp":
        net = U2NETP(3, 1, se_type=se_type)
    elif model_name == "u2net_heavy":
        net = u2net_heavy()
    elif model_name == "custom":
        net = CustomNet()
    else:
        raise ValueError(f"Unknown model_name: {model_name}")

    # Restore model weights from checkpoint
    if checkpoint:
        if not os.path.exists(checkpoint):
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint}")

        try:
            print(f"Restoring from checkpoint: {checkpoint}")
            net.load_state_dict(torch.load(checkpoint, map_location="cpu"))
            print(" - [x] success")
        except Exception as e:
            print(f" - [!] error: {e}")

    if torch.cuda.is_available():
        net.cuda()

    if gan_training:
        netD = MultiScaleNLayerDiscriminator()

        if checkpoint_netD:
            if not os.path.exists(checkpoint_netD):
                raise FileNotFoundError(
                    f"Discriminator checkpoint file not found: {checkpoint_netD}"
                )

            try:
                print(
                    f"Restoring discriminator from checkpoint: {checkpoint_netD}"
                )
                netD.load_state_dict(
                    torch.load(checkpoint_netD, map_location="cpu"))
                print(" - [x] success")
            except Exception as e:
                print(f" - [!] error: {e}")

        if torch.cuda.is_available():
            netD.cuda()

        vgg19 = VGG19Features()
        vgg19.eval()
        if torch.cuda.is_available():
            vgg19 = vgg19.cuda()

    # ---------------------------------------------------------
    # 3. Define optimizer
    # ---------------------------------------------------------

    optimizer = optim.Adam(net.parameters(),
                           lr=lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0)
    # optimizer = optim.SGD(net.parameters(), lr=lr)
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr/4, max_lr=lr,
    #                                         mode="triangular2",
    #                                         step_size_up=2 * len(salobj_dataloader))

    if gan_training:
        optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.9))

    # ---------------------------------------------------------
    # 4. Initialize AMP and data parallel stuffs
    # ---------------------------------------------------------

    GOT_AMP = False
    if mixed_precision_training:
        try:
            print("Checking for Apex AMP support...")
            from apex import amp
            GOT_AMP = True
            print(" - [x] yes")
        except ImportError:
            print(" - [!] no")

    if GOT_AMP:
        amp.register_float_function(torch, 'sigmoid')
        net, optimizer = amp.initialize(net, optimizer, opt_level="O1")

        if gan_training:
            netD, optimizerD = amp.initialize(netD, optimizerD, opt_level="O1")
            vgg19 = amp.initialize(vgg19, opt_level="O1")

    if torch.cuda.device_count() > 1 and multi_gpu:
        print(f"Multi-GPU training using {torch.cuda.device_count()} GPUs.")
        net = nn.DataParallel(net)

        if gan_training:
            netD = nn.DataParallel(netD)
            vgg19 = nn.DataParallel(vgg19)
    else:
        print(f"Training using {torch.cuda.device_count()} GPUs.")

    # ---------------------------------------------------------
    # 5. Training
    # ---------------------------------------------------------

    print("Start training...")

    ite_num = 0
    ite_num4val = 0
    running_loss = 0.0
    running_bce_loss = 0.0
    running_tar_loss = 0.0
    running_adv_loss = 0.0
    running_per_loss = 0.0
    running_fake_loss = 0.0
    running_real_loss = 0.0
    running_lossD = 0.0

    for epoch in tqdm(range(0, epoch_num), desc="All epochs"):
        net.train()
        if gan_training:
            netD.train()

        for i, data in enumerate(
                tqdm(salobj_dataloader, desc=f"Epoch #{epoch}")):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            image_key = "image"
            label_key = "label"
            inputs, labels = data[image_key], data[label_key]
            # tqdm.write(f"input tensor shape: {inputs.shape}")

            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)

            # Wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = \
                    Variable(inputs.cuda(), requires_grad=False), \
                    Variable(labels.cuda(), requires_grad=False)
            else:
                inputs_v, labels_v = \
                    Variable(inputs, requires_grad=False), \
                    Variable(labels, requires_grad=False)

            # # Zero the parameter gradients
            # optimizer.zero_grad()

            # Forward + backward + optimize

            d6 = 0
            if model_name == "custom":
                d0, d1, d2, d3, d4, d5 = net(inputs_v)
            else:
                d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)

            if gan_training:
                optimizerD.zero_grad()

                dis_fake = netD(inputs_v, d0.detach())
                dis_real = netD(inputs_v, labels_v)

                loss_fake = bce_with_logits_loss(dis_fake,
                                                 torch.zeros_like(dis_fake))
                loss_real = bce_with_logits_loss(dis_real,
                                                 torch.ones_like(dis_real))
                lossD = loss_fake + loss_real

                if GOT_AMP:
                    with amp.scale_loss(lossD, optimizerD) as scaled_loss:
                        scaled_loss.backward()
                else:
                    lossD.backward()

                optimizerD.step()

                writer.add_scalar("lossD/fake", loss_fake.item(), ite_num)
                writer.add_scalar("lossD/real", loss_real.item(), ite_num)
                writer.add_scalar("lossD/sum", lossD.item(), ite_num)
                running_fake_loss += loss_fake.item()
                running_real_loss += loss_real.item()
                running_lossD += lossD.item()

            # Zero the parameter gradients
            optimizer.zero_grad()

            if model_name == "custom":
                loss2, loss = multi_bce_loss_fusion5(d0, d1, d2, d3, d4, d5,
                                                     labels_v)
            else:
                loss2, loss = multi_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6,
                                                    labels_v)

            writer.add_scalar("lossG/bce", loss.item(), ite_num)
            running_bce_loss += loss.item()

            if gan_training:
                # Adversarial loss
                loss_adv = 0.0
                if w_adv:
                    dis_fake = netD(inputs_v, d0)
                    loss_adv = bce_with_logits_loss(dis_fake,
                                                    torch.ones_like(dis_fake))

                # Perceptual loss
                loss_per = 0.0
                if w_vgg:
                    vgg19_fm_pred = vgg19(inputs_v * d0)
                    vgg19_fm_label = vgg19(inputs_v * labels_v)
                    loss_per = mae_loss(vgg19_fm_pred, vgg19_fm_label)

                loss = loss + w_adv * loss_adv + w_vgg * loss_per

            if GOT_AMP:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            # scheduler.step()

            writer.add_scalar("lossG/sum", loss.item(), ite_num)
            writer.add_scalar("lossG/loss2", loss2.item(), ite_num)
            running_loss += loss.item()
            running_tar_loss += loss2.item()
            if gan_training:
                writer.add_scalar("lossG/adv", loss_adv.item(), ite_num)
                writer.add_scalar("lossG/perceptual", loss_per.item(), ite_num)
                running_adv_loss += loss_adv.item()
                running_per_loss += loss_per.item()

            if ite_num % 200 == 0:
                writer.add_images("inputs", inv_normalize(inputs_v), ite_num)
                writer.add_images("labels", labels_v, ite_num)
                writer.add_images("preds", d0, ite_num)

            # Delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
            if gan_training:
                del dis_fake, dis_real, loss_fake, loss_real, lossD, loss_adv, vgg19_fm_pred, vgg19_fm_label, loss_per

            # Print stats
            tqdm.write(
                "[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train G/sum: %3f, G/bce: %3f, G/bce_tar: %3f, G/adv: %3f, G/percept: %3f, D/fake: %3f, D/real: %3f, D/sum: %3f"
                % (epoch + 1, epoch_num,
                   (i + 1) * batch_size_train, train_num, ite_num,
                   running_loss / ite_num4val, running_bce_loss / ite_num4val,
                   running_tar_loss / ite_num4val, running_adv_loss /
                   ite_num4val, running_per_loss / ite_num4val,
                   running_fake_loss / ite_num4val, running_real_loss /
                   ite_num4val, running_lossD / ite_num4val))

            if ite_num % save_frq == 0:
                # Save checkpoint
                torch.save(
                    net.module.state_dict() if hasattr(
                        net, "module") else net.state_dict(), model_dir +
                    model_name + (("_" + se_type) if se_type else "") +
                    ("_" + dataset_name) +
                    ("_mixup_aug" if mixup_augmentation else "") +
                    ("_heavy_aug" if heavy_augmentation else "") +
                    ("_fullsize" if fullsize_training else "") +
                    ("_multiscale" if multiscale_training else "") +
                    "_bce_itr_%d_train_%3f_tar_%3f.pth" %
                    (ite_num, running_loss / ite_num4val,
                     running_tar_loss / ite_num4val))

                if gan_training:
                    torch.save(
                        netD.module.state_dict() if hasattr(netD, "module")
                        else netD.state_dict(), model_dir + "netD_" +
                        model_name + (("_" + se_type) if se_type else "") +
                        ("_" + dataset_name) +
                        ("_mixup_aug" if mixup_augmentation else "") +
                        ("_heavy_aug" if heavy_augmentation else "") +
                        ("_fullsize" if fullsize_training else "") +
                        ("_multiscale" if multiscale_training else "") +
                        "itr_%d.pth" % (ite_num))

                # Reset stats
                running_loss = 0.0
                running_bce_loss = 0.0
                running_tar_loss = 0.0
                running_adv_loss = 0.0
                running_per_loss = 0.0
                running_fake_loss = 0.0
                running_real_loss = 0.0
                running_lossD = 0.0
                ite_num4val = 0

                net.train()  # resume train
                if gan_training:
                    netD.train()

    writer.close()
    print("Training completed successfully.")
Example #13
def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True)
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              collate_fn=collater,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = retinanet.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        if local_rank == 0:
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.evaluate))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            logger.info(f"start eval.")
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
            )

    if local_rank == 0:
        if not os.path.exists(args.checkpoints):
            os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               args)
        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 5 == 0 or epoch == args.epochs:
            if local_rank == 0:
                logger.info(f"start eval.")
                all_eval_result = validate(Config.val_dataset, model, decoder)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                    )
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]
        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'cls_loss': cls_losses,
                    'reg_loss': reg_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")
Example #14
import argparse

import numpy as np
import torch

import apex
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

import segmentation_models_pytorch as smp
from model.dinknet import DinkNet34, DinkNet50, DinkNet101

from config import cfg
from utils import *
from dataset import AgriTrainDataset, AgriValDataset
from model.deeplab import DeepLab
from model.loss import ComposedLossWithLogits

torch.manual_seed(42)
np.random.seed(42)
amp.register_float_function(torch, 'sigmoid')

INF_FP16 = 2 ** 15

def parse_args():
    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Training"
    )

    parser.add_argument(
        "--local_rank",
        default=0,
        type=int
    )

    parser.add_argument(
Example #15
def main(args):

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k not in ['resume', 'local_rank']:
                setattr(args, k, v)

    args = compute_subsampling_factor(args)
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0

    # Load dataset
    train_set = build_dataloader(args=args,
                                 tsv_path=args.train_set,
                                 tsv_path_sub1=args.train_set_sub1,
                                 tsv_path_sub2=args.train_set_sub2,
                                 batch_size=args.batch_size,
                                 batch_size_type=args.batch_size_type,
                                 max_n_frames=args.max_n_frames,
                                 resume_epoch=resume_epoch,
                                 sort_by=args.sort_by,
                                 short2long=args.sort_short2long,
                                 sort_stop_epoch=args.sort_stop_epoch,
                                 num_workers=args.workers,
                                 pin_memory=args.pin_memory,
                                 distributed=args.distributed,
                                 word_alignment_dir=args.train_word_alignment,
                                 ctc_alignment_dir=args.train_ctc_alignment)
    dev_set = build_dataloader(
        args=args,
        tsv_path=args.dev_set,
        tsv_path_sub1=args.dev_set_sub1,
        tsv_path_sub2=args.dev_set_sub2,
        batch_size=1 if 'transducer' in args.dec_type else args.batch_size,
        batch_size_type='seq'
        if 'transducer' in args.dec_type else args.batch_size_type,
        max_n_frames=1600,
        word_alignment_dir=args.dev_word_alignment,
        ctc_alignment_dir=args.dev_ctc_alignment)
    eval_sets = [
        build_dataloader(args=args, tsv_path=s, batch_size=1, is_test=True)
        for s in args.eval_sets
    ]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            assert args.asr_init
            args.save_path = mkdir_join(os.path.dirname(args.asr_init),
                                        dir_name)
        else:
            args.save_path = mkdir_join(
                args.model_save_dir,
                '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
                dir_name)
        if args.local_rank > 0:
            time.sleep(1)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), args.stdout,
               args.local_rank)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(
            os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, args.save_path, train_set.idx2token[0])

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path,
                                                  'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if args.get('dict' + sub):
                shutil.copy(
                    args.get('dict' + sub),
                    os.path.join(args.save_path, 'dict' + sub + '.txt'))
            if args.get('unit' + sub) == 'wp':
                shutil.copy(
                    args.get('wp_model' + sub),
                    os.path.join(args.save_path, 'wp' + sub + '.model'))

        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load ASR model (full model)
            conf_init = load_config(
                os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type or 'former' in args.dec_type_sub1
    scheduler = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        lower_better=args.metric not in ['accuracy', 'bleu'],
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        peak_lr=0.05 / (args.get('transformer_enc_d_model', 0)**0.5)
        if 'conformer' in args.enc_type else 1e6,
        model_size=args.get('transformer_enc_d_model',
                            args.get('transformer_dec_d_model', 0)),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # Load teacher ASR model
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(
            os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(
            os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=((not is_transformer) and (not args.cudnn_benchmark))
            or args.cudnn_deterministic,
            benchmark=(not is_transformer) and args.cudnn_benchmark)

        # Mixed precision training setting
        if args.use_apex:
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                from neural_sp.models.seq2seq.decoders.ctc import CTC
                amp.register_float_function(CTC, "loss_fn")
                # NOTE: see https://github.com/espnet/espnet/pull/1779
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)

        n = torch.cuda.device_count() // args.local_world_size
        device_ids = list(range(args.local_rank * n,
                                (args.local_rank + 1) * n))

        torch.cuda.set_device(device_ids[0])
        model.cuda(device_ids[0])
        scheduler.cuda(device_ids[0])
        if args.distributed:
            model = DDP(model, device_ids=device_ids)
        else:
            model = CustomDataParallel(model,
                                       device_ids=list(range(args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('HOSTNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model, args.local_rank)
    args.wandb_id = reporter.wandb_id
    if args.resume:
        n_steps = scheduler.n_steps * max(
            1, args.accum_grad_n_steps // args.local_world_size)
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if args.local_rank == 0:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf,
                        os.path.join(args.save_path, 'conf_lm.yml'))
        # NOTE: save after reporter for wandb ID

    # Define tasks
    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if args.total_weight - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if args.get('train_set_' + sub) is not None:
                if args.get(sub + '_weight', 0) - args.get(
                        'ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub] + tasks
                if args.get('ctc_weight_' + sub, 0) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    if args.get('ss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_scheduled_sampling()
    if args.get('mocha_quantity_loss_start_epoch', 0) <= resume_epoch:
        model.module.trigger_quantity_loss()

    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        train_one_epoch(model, train_set, dev_set, eval_sets, scheduler,
                        reporter, logger, args, amp, scaler, tasks, teacher,
                        teacher_lm)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            if args.local_rank == 0:
                scheduler.save_checkpoint(model,
                                          args.save_path,
                                          amp=amp,
                                          remove_old=(not is_transformer)
                                          and args.remove_old_checkpoints)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = validate([model.module], dev_set, args,
                                  reporter.n_epochs + 1, logger)
            scheduler.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot
            reporter.add_scalar('dev/' + args.metric, metric_dev)

            if scheduler.is_topk or is_transformer:
                # Save model
                if args.local_rank == 0:
                    scheduler.save_checkpoint(model,
                                              args.save_path,
                                              amp=amp,
                                              remove_old=(not is_transformer)
                                              and args.remove_old_checkpoints)

                # test
                if scheduler.is_topk:
                    for eval_set in eval_sets:
                        validate([model.module], eval_set, args,
                                 reporter.n_epochs, logger)

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model,
                                         args.lr,
                                         args.weight_decay,
                                         decay_type='always',
                                         decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break
        if args.get('ss_start_epoch', 0) == (ep + 1):
            model.module.trigger_scheduled_sampling()
        if args.get('mocha_quantity_loss_start_epoch', 0) == (ep + 1):
            model.module.trigger_quantity_loss()

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
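Example #15 selects either native torch.cuda.amp (a GradScaler) or apex amp depending on the installed torch version. Below is a minimal sketch of how the resulting amp/scaler pair is typically consumed inside the training step; the helper name is hypothetical and train_one_epoch in the source project is considerably more involved.

def backward_step(loss, optimizer, amp=None, scaler=None):
    # Native AMP path (torch >= 1.6): scale the loss with GradScaler.
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    # apex path: scale the loss through amp.scale_loss.
    elif amp is not None:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
    else:
        loss.backward()
        optimizer.step()
    optimizer.zero_grad()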
Example #16
try:
    import factor_multiply_fast as fmf
    from factor_multiply_fast import butterfly_multiply_untied_forward_backward_fast
    from factor_multiply_fast import butterfly_bbs_multiply_untied_forward_fast
    from factor_multiply_fast import butterfly_bbs_multiply_untied_forward_backward_fast
    from factor_multiply_fast import butterfly_odo_multiply_untied_forward_fast
    from factor_multiply_fast import butterfly_odo_multiply_untied_backward_fast
    from factor_multiply_fast import butterfly_odo_multiply_untied_forward_backward_fast
except:
    use_extension = False
    import warnings
    warnings.warn(
        "C++/CUDA extension isn't installed properly. Will use the butterfly multiply implemented in PyTorch, which is much slower."
    )

try:
    from apex import amp
    amp.register_float_function(fmf,
                                'butterfly_odo_multiply_untied_forward_fast')
    amp.register_float_function(
        fmf, 'butterfly_odo_multiply_untied_forward_backward_fast')
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex.")


def butterfly_mult_torch(twiddle,
                         input,
                         increasing_stride=True,
                         return_intermediates=False):
    """
    Parameters:
        twiddle: (nstack, n - 1, 2, 2) if real or (nstack, n - 1, 2, 2, 2) if complex
        input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
    """
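For orientation, here is a minimal stand-alone sketch (not the library's butterfly_mult_torch) of a single butterfly step with the shapes described in the docstring above: each pair of entries at distance `stride` is mixed by its own 2x2 twiddle matrix, and a full transform chains log2(n) such steps with strides 1, 2, 4, ... The helper name butterfly_step and the toy shapes are illustrative assumptions.

import torch

def butterfly_step(x, twiddle, stride):
    # x: (batch, n) real input; twiddle: (n // 2, 2, 2); n and stride are powers of two
    batch, n = x.shape
    xs = x.view(batch, n // (2 * stride), 2, stride)   # pair up entries at distance `stride`
    t = twiddle.view(n // (2 * stride), stride, 2, 2)  # one 2x2 matrix per pair
    ys = torch.einsum('gsqp,bgps->bgqs', t, xs)        # y_q = sum_p t[q, p] * x_p for every pair
    return ys.reshape(batch, n)

x = torch.randn(3, 8)
y = butterfly_step(x, torch.randn(4, 2, 2), stride=1)  # a full transform would continue with strides 2 and 4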
Example #17
def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    cur_timestamp = str(datetime.now())[:-3]  # include milliseconds to reduce the chance of a name collision
    model_width = {'linear': '', 'cnn': args.n_filters_cnn, 'lenet': '', 'resnet18': ''}[args.model]
    model_str = '{}{}'.format(args.model, model_width)
    model_name = '{} dataset={} model={} eps={} attack={} m={} attack_init={} fgsm_alpha={} epochs={} pgd={}-{} grad_align_cos_lambda={} lr_max={} seed={}'.format(
        cur_timestamp, args.dataset, model_str, args.eps, args.attack, args.minibatch_replay, args.attack_init, args.fgsm_alpha, args.epochs,
        args.pgd_alpha_train, args.pgd_train_n_iters, args.grad_align_cos_lambda, args.lr_max, args.seed)
    if not os.path.exists('models'):
        os.makedirs('models')
    logger = utils.configure_logger(model_name, args.debug)
    logger.info(args)
    half_prec = args.half_prec
    n_cls = 2 if 'binary' in args.dataset else 10

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    double_bp = True if args.grad_align_cos_lambda > 0 else False
    n_eval_every_k_iter = args.n_eval_every_k_iter
    args.pgd_alpha = args.eps / 4

    eps, pgd_alpha, pgd_alpha_train = args.eps / 255, args.pgd_alpha / 255, args.pgd_alpha_train / 255
    train_data_augm = False if args.dataset in ['mnist'] else True
    train_batches = data.get_loaders(args.dataset, -1, args.batch_size, train_set=True, shuffle=True, data_augm=train_data_augm)
    train_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size, train_set=True, shuffle=False, data_augm=False)
    test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)
    test_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)

    model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn).cuda()
    model.apply(utils.initialize_weights)
    model.train()

    if args.model == 'resnet18':
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    elif args.model == 'cnn':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    elif args.model == 'lenet':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    else:
        raise ValueError('decide about the right optimizer for the new model')

    if half_prec:
        if double_bp:
            amp.register_float_function(torch, 'batch_norm')
        model, opt = amp.initialize(model, opt, opt_level="O1")

    if args.attack == 'fgsm':  # needed here only for Free-AT
        delta = torch.zeros(args.batch_size, *data.shapes_dict[args.dataset][1:]).cuda()
        delta.requires_grad = True

    lr_schedule = utils.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_max)
    loss_function = nn.CrossEntropyLoss()

    train_acc_pgd_best, best_state_dict = 0.0, copy.deepcopy(model.state_dict())
    start_time = time.time()
    time_train, iteration, best_iteration = 0, 0, 0
    for epoch in range(args.epochs + 1):
        train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0
        for i, (X, y) in enumerate(train_batches):
            if i % args.minibatch_replay != 0 and i > 0:  # take new inputs only each `minibatch_replay` iterations
                X, y = X_prev, y_prev
            time_start_iter = time.time()
            # epoch=0 runs only for one iteration (to check the training stats at init)
            if epoch == 0 and i > 0:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches))  # epoch - 1 since the 0th epoch is skipped
            opt.param_groups[0].update(lr=lr)

            if args.attack in ['pgd', 'pgd_corner']:
                pgd_rs = True if args.attack_init == 'random' else False
                n_eps_warmup_epochs = 5
                n_iterations_max_eps = n_eps_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                eps_pgd_train = min(iteration / n_iterations_max_eps * eps, eps) if args.dataset == 'svhn' else eps
                delta = utils.attack_pgd_training(
                    model, X, y, eps_pgd_train, pgd_alpha_train, opt, half_prec, args.pgd_train_n_iters, rs=pgd_rs)
                if args.attack == 'pgd_corner':
                    delta = eps * utils.sign(delta)  # project to the corners
                    delta = clamp(X + delta, 0, 1) - X

            elif args.attack == 'fgsm':
                if args.minibatch_replay == 1:
                    if args.attack_init == 'zero':
                        delta = torch.zeros_like(X, requires_grad=True)
                    elif args.attack_init == 'random':
                        delta = utils.get_uniform_delta(X.shape, eps, requires_grad=True)
                    else:
                        raise ValueError('wrong args.attack_init')
                else:  # if Free-AT, we just reuse the existing delta from the previous iteration
                    delta.requires_grad = True

                X_adv = clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                if half_prec:
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        grad = torch.autograd.grad(scaled_loss, delta, create_graph=True if double_bp else False)[0]
                        grad /= scaled_loss / loss  # undo the loss scaling applied by amp
                else:
                    grad = torch.autograd.grad(loss, delta, create_graph=True if double_bp else False)[0]

                grad = grad.detach()

                argmax_delta = eps * utils.sign(grad)

                n_alpha_warmup_epochs = 5
                n_iterations_max_alpha = n_alpha_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                fgsm_alpha = min(iteration / n_iterations_max_alpha * args.fgsm_alpha, args.fgsm_alpha) if args.dataset == 'svhn' else args.fgsm_alpha
                delta.data = clamp(delta.data + fgsm_alpha * argmax_delta, -eps, eps)
                delta.data = clamp(X + delta.data, 0, 1) - X

            elif args.attack == 'random_corner':
                delta = utils.get_uniform_delta(X.shape, eps, requires_grad=False)
                delta = eps * utils.sign(delta)

            elif args.attack == 'none':
                delta = torch.zeros_like(X, requires_grad=False)
            else:
                raise ValueError('wrong args.attack')

            # extra FP+BP to calculate the gradient to monitor it
            if args.attack in ['none', 'random_corner', 'pgd', 'pgd_corner']:
                grad = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                                      backprop=args.grad_align_cos_lambda != 0.0)

            delta = delta.detach()

            output = model(X + delta)
            loss = loss_function(output, y)

            reg = torch.zeros(1).cuda()[0]  # for .item() to run correctly
            if args.grad_align_cos_lambda != 0.0:
                grad2 = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='random_uniform', backprop=True)
                grads_nnz_idx = ((grad**2).sum([1, 2, 3])**0.5 != 0) * ((grad2**2).sum([1, 2, 3])**0.5 != 0)
                grad1, grad2 = grad[grads_nnz_idx], grad2[grads_nnz_idx]
                grad1_norms, grad2_norms = l2_norm_batch(grad1), l2_norm_batch(grad2)
                grad1_normalized = grad1 / grad1_norms[:, None, None, None]
                grad2_normalized = grad2 / grad2_norms[:, None, None, None]
                cos = torch.sum(grad1_normalized * grad2_normalized, (1, 2, 3))
                reg += args.grad_align_cos_lambda * (1.0 - cos.mean())

            loss += reg

            if epoch != 0:
                opt.zero_grad()
                utils.backward(loss, opt, half_prec)
                opt.step()

            time_train += time.time() - time_start_iter
            train_loss += loss.item() * y.size(0)
            train_reg += reg.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

            with torch.no_grad():  # no grad for the stats
                grad_norm_x += l2_norm_batch(grad).sum().item()
                delta_final = clamp(X + delta, 0, 1) - X  # we should measure delta after the projection onto [0, 1]^d
                avg_delta_l2 += ((delta_final ** 2).sum([1, 2, 3]) ** 0.5).sum().item()

            if iteration % args.eval_iter_freq == 0:
                train_loss, train_reg = train_loss / train_n, train_reg / train_n
                train_acc, avg_delta_l2 = train_acc / train_n, avg_delta_l2 / train_n

                # it'd be incorrect to recalculate the BN stats on the test sets and for clean / adversarial points
                utils.model_eval(model, half_prec)

                test_acc_clean, _, _ = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_fgsm, test_loss_fgsm, fgsm_deltas = rob_acc(test_batches_fast, model, eps, eps, opt, half_prec, 1, 1, rs=False)
                test_acc_pgd, test_loss_pgd, pgd_deltas = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)
                cos_fgsm_pgd = utils.avg_cos_np(fgsm_deltas, pgd_deltas)
                train_acc_pgd, _, _ = rob_acc(train_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)  # needed for early stopping

                grad_x = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=False)
                grad_eta = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=True)
                cos_x_eta = utils.avg_cos_np(grad_x, grad_eta)

                time_elapsed = time.time() - start_time
                train_str = '[train] loss {:.3f}, reg {:.3f}, acc {:.2%} acc_pgd {:.2%}'.format(train_loss, train_reg, train_acc, train_acc_pgd)
                test_str = '[test] acc_clean {:.2%}, acc_fgsm {:.2%}, acc_pgd {:.2%}, cos_x_eta {:.3}, cos_fgsm_pgd {:.3}'.format(
                    test_acc_clean, test_acc_fgsm, test_acc_pgd, cos_x_eta, cos_fgsm_pgd)
                logger.info('{}-{}: {}  {} ({:.2f}m, {:.2f}m)'.format(epoch, iteration, train_str, test_str,
                                                                      time_train/60, time_elapsed/60))

                if train_acc_pgd > train_acc_pgd_best:  # catastrophic overfitting can be detected on the training set
                    best_state_dict = copy.deepcopy(model.state_dict())
                    train_acc_pgd_best, best_iteration = train_acc_pgd, iteration

                utils.model_train(model, half_prec)
                train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0

            iteration += 1
            X_prev, y_prev = X.clone(), y.clone()  # needed for Free-AT

        if epoch == args.epochs:
            torch.save({'last': model.state_dict(), 'best': best_state_dict}, 'models/{} epoch={}.pth'.format(model_name, epoch))
            # disable global conversion to fp16 from amp.initialize() (https://github.com/NVIDIA/apex/issues/567)
            context_manager = amp.disable_casts() if half_prec else utils.nullcontext()
            with context_manager:
                last_state_dict = copy.deepcopy(model.state_dict())
                half_prec = False  # final eval is always in fp32
                model.load_state_dict(last_state_dict)
                utils.model_eval(model, half_prec)
                opt = torch.optim.SGD(model.parameters(), lr=0)

                attack_iters, n_restarts = (50, 10) if not args.debug else (10, 3)
                test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                logger.info('[last: test on 10k points] acc_clean {:.2%}, pgd_rr {:.2%}'.format(test_acc_clean, test_acc_pgd_rr))

                if args.eval_early_stopped_model:
                    model.load_state_dict(best_state_dict)
                    utils.model_eval(model, half_prec)
                    test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                    test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                    logger.info('[best: test on 10k points][iter={}] acc_clean {:.2%}, pgd_rr {:.2%}'.format(
                        best_iteration, test_acc_clean, test_acc_pgd_rr))

        utils.model_train(model, half_prec)

    logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
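The reg term computed in the loop above is the gradient-alignment regularizer: it penalizes misalignment between the input gradient at the clean point and at a uniformly perturbed point inside the eps-ball. A condensed, stand-alone sketch of that computation follows; grad_align_reg and its arguments are illustrative assumptions (the example itself obtains the two gradients via get_input_grad), and inputs are assumed to be image batches of shape (B, C, H, W).

import torch
import torch.nn.functional as F

def grad_align_reg(model, X, y, eps, lam):
    def input_grad(delta):
        delta = delta.requires_grad_(True)
        loss = F.cross_entropy(model(X + delta), y)
        return torch.autograd.grad(loss, delta, create_graph=True)[0]

    g1 = input_grad(torch.zeros_like(X))                      # gradient at the clean input
    g2 = input_grad(torch.empty_like(X).uniform_(-eps, eps))  # gradient at a random point in the eps-ball
    g1 = g1 / g1.flatten(1).norm(dim=1).clamp_min(1e-12)[:, None, None, None]
    g2 = g2 / g2.flatten(1).norm(dim=1).clamp_min(1e-12)[:, None, None, None]
    cos = (g1 * g2).flatten(1).sum(1)                         # per-example cosine similarity
    return lam * (1.0 - cos.mean())                           # added to the training loss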
Example #18
def main():

    args = parse_args_train(sys.argv[1:])
    args_init = copy.deepcopy(args)
    args_teacher = copy.deepcopy(args)

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k != 'resume':
                setattr(args, k, v)
    recog_params = vars(args)

    args = compute_susampling_factor(args)

    # Load dataset
    batch_size = args.batch_size * args.n_gpus if args.n_gpus >= 1 else args.batch_size
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        tsv_path_sub1=args.train_set_sub1,
                        tsv_path_sub2=args.train_set_sub2,
                        dict_path=args.dict,
                        dict_path_sub1=args.dict_sub1,
                        dict_path_sub2=args.dict_sub2,
                        nlsyms=args.nlsyms,
                        unit=args.unit,
                        unit_sub1=args.unit_sub1,
                        unit_sub2=args.unit_sub2,
                        wp_model=args.wp_model,
                        wp_model_sub1=args.wp_model_sub1,
                        wp_model_sub2=args.wp_model_sub2,
                        batch_size=batch_size,
                        n_epochs=args.n_epochs,
                        min_n_frames=args.min_n_frames,
                        max_n_frames=args.max_n_frames,
                        shuffle_bucket=args.shuffle_bucket,
                        sort_by='input',
                        short2long=args.sort_short2long,
                        sort_stop_epoch=args.sort_stop_epoch,
                        dynamic_batching=args.dynamic_batching,
                        ctc=args.ctc_weight > 0,
                        ctc_sub1=args.ctc_weight_sub1 > 0,
                        ctc_sub2=args.ctc_weight_sub2 > 0,
                        subsample_factor=args.subsample_factor,
                        subsample_factor_sub1=args.subsample_factor_sub1,
                        subsample_factor_sub2=args.subsample_factor_sub2,
                        discourse_aware=args.discourse_aware)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      tsv_path_sub1=args.dev_set_sub1,
                      tsv_path_sub2=args.dev_set_sub2,
                      dict_path=args.dict,
                      dict_path_sub1=args.dict_sub1,
                      dict_path_sub2=args.dict_sub2,
                      nlsyms=args.nlsyms,
                      unit=args.unit,
                      unit_sub1=args.unit_sub1,
                      unit_sub2=args.unit_sub2,
                      wp_model=args.wp_model,
                      wp_model_sub1=args.wp_model_sub1,
                      wp_model_sub2=args.wp_model_sub2,
                      batch_size=batch_size,
                      min_n_frames=args.min_n_frames,
                      max_n_frames=args.max_n_frames,
                      ctc=args.ctc_weight > 0,
                      ctc_sub1=args.ctc_weight_sub1 > 0,
                      ctc_sub2=args.ctc_weight_sub2 > 0,
                      subsample_factor=args.subsample_factor,
                      subsample_factor_sub1=args.subsample_factor_sub1,
                      subsample_factor_sub2=args.subsample_factor_sub2)
    eval_sets = [Dataset(corpus=args.corpus,
                         tsv_path=s,
                         dict_path=args.dict,
                         nlsyms=args.nlsyms,
                         unit=args.unit,
                         wp_model=args.wp_model,
                         batch_size=1,
                         is_test=True) for s in args.eval_sets]

    args.vocab = train_set.vocab
    args.vocab_sub1 = train_set.vocab_sub1
    args.vocab_sub2 = train_set.vocab_sub2
    args.input_dim = train_set.input_dim

    # Set save path
    if args.resume:
        save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(save_path)
    else:
        dir_name = set_asr_model_name(args)
        if args.mbr_training:
            assert args.asr_init
            save_path = mkdir_join(os.path.dirname(args.asr_init), dir_name)
        else:
            save_path = mkdir_join(args.model_save_dir, '_'.join(
                os.path.basename(args.train_set).split('.')[:-1]), dir_name)
        save_path = set_save_path(save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(save_path, 'train.log'), stdout=args.stdout)

    # Load a LM conf file for LM fusion & LM initialization
    if not args.resume and args.external_lm:
        lm_conf = load_config(os.path.join(os.path.dirname(args.external_lm), 'conf.yml'))
        args.lm_conf = argparse.Namespace()
        for k, v in lm_conf.items():
            setattr(args.lm_conf, k, v)
        assert args.unit == args.lm_conf.unit
        assert args.vocab == args.lm_conf.vocab

    # Model setting
    model = Speech2Text(args, save_path, train_set.idx2token[0])

    if not args.resume:
        # Save the conf file as a yaml file
        save_config(vars(args), os.path.join(save_path, 'conf.yml'))
        if args.external_lm:
            save_config(args.lm_conf, os.path.join(save_path, 'conf_lm.yml'))

        # Save the nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(save_path, 'nlsyms.txt'))
        for sub in ['', '_sub1', '_sub2']:
            if getattr(args, 'dict' + sub):
                shutil.copy(getattr(args, 'dict' + sub), os.path.join(save_path, 'dict' + sub + '.txt'))
            if getattr(args, 'unit' + sub) == 'wp':
                shutil.copy(getattr(args, 'wp_model' + sub), os.path.join(save_path, 'wp' + sub + '.model'))

        for k, v in sorted(vars(args).items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info(model)

        # Initialize with pre-trained model's parameters
        if args.asr_init:
            # Load the ASR model (full model)
            conf_init = load_config(os.path.join(os.path.dirname(args.asr_init), 'conf.yml'))
            for k, v in conf_init.items():
                setattr(args_init, k, v)
            model_init = Speech2Text(args_init)
            load_checkpoint(args.asr_init, model_init)

            # Overwrite parameters
            param_dict = dict(model_init.named_parameters())
            for n, p in model.named_parameters():
                if n in param_dict.keys() and p.size() == param_dict[n].size():
                    if args.asr_init_enc_only and 'enc' not in n:
                        continue
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s' % n)

    # Set optimizer
    resume_epoch = 0
    if args.resume:
        resume_epoch = int(args.resume.split('-')[-1])
        optimizer = set_optimizer(model, 'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
                                  args.lr, args.weight_decay)
    else:
        optimizer = set_optimizer(model, args.optimizer, args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = 'former' in args.enc_type or 'former' in args.dec_type
    optimizer = LRScheduler(optimizer, args.lr,
                            decay_type=args.lr_decay_type,
                            decay_start_epoch=args.lr_decay_start_epoch,
                            decay_rate=args.lr_decay_rate,
                            decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
                            early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
                            lower_better=args.metric not in ['accuracy', 'bleu'],
                            warmup_start_lr=args.warmup_start_lr,
                            warmup_n_steps=args.warmup_n_steps,
                            model_size=getattr(args, 'transformer_d_model', 0),
                            factor=args.lr_factor,
                            noam=args.optimizer == 'noam',
                            save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, optimizer)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # Load the teacher ASR model
    teacher = None
    if args.teacher:
        assert os.path.isfile(args.teacher), 'There is no checkpoint.'
        conf_teacher = load_config(os.path.join(os.path.dirname(args.teacher), 'conf.yml'))
        for k, v in conf_teacher.items():
            setattr(args_teacher, k, v)
        # Setting for knowledge distillation
        args_teacher.ss_prob = 0
        args.lsm_prob = 0
        teacher = Speech2Text(args_teacher)
        load_checkpoint(args.teacher, teacher)

    # Load the teacher LM
    teacher_lm = None
    if args.teacher_lm:
        assert os.path.isfile(args.teacher_lm), 'There is no checkpoint.'
        conf_lm = load_config(os.path.join(os.path.dirname(args.teacher_lm), 'conf.yml'))
        args_lm = argparse.Namespace()
        for k, v in conf_lm.items():
            setattr(args_lm, k, v)
        teacher_lm = build_lm(args_lm)
        load_checkpoint(args.teacher_lm, teacher_lm)

    # GPU setting
    use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp = None
    if args.n_gpus >= 1:
        model.cudnn_setting(deterministic=not (is_transformer or args.cudnn_benchmark),
                            benchmark=not is_transformer and args.cudnn_benchmark)
        model.cuda()

        # Mix precision training setting
        if use_apex:
            from apex import amp
            model, optimizer.optimizer = amp.initialize(model, optimizer.optimizer,
                                                        opt_level=args.train_dtype)
            from neural_sp.models.seq2seq.decoders.ctc import CTC
            amp.register_float_function(CTC, "loss_fn")
            # NOTE: see https://github.com/espnet/espnet/pull/1779
            amp.init()
            if args.resume:
                load_checkpoint(args.resume, amp=amp)
        model = CustomDataParallel(model, device_ids=list(range(0, args.n_gpus)))

        if teacher is not None:
            teacher.cuda()
        if teacher_lm is not None:
            teacher_lm.cuda()
    else:
        model = CPUWrapperASR(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(save_path)

    if args.mtl_per_batch:
        # NOTE: from easier to harder tasks
        tasks = []
        if 1 - args.bwd_weight - args.ctc_weight - args.sub1_weight - args.sub2_weight > 0:
            tasks += ['ys']
        if args.bwd_weight > 0:
            tasks = ['ys.bwd'] + tasks
        if args.ctc_weight > 0:
            tasks = ['ys.ctc'] + tasks
        if args.mbr_ce_weight > 0:
            tasks = ['ys.mbr'] + tasks
        for sub in ['sub1', 'sub2']:
            if getattr(args, 'train_set_' + sub):
                if getattr(args, sub + '_weight') - getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub] + tasks
                if getattr(args, 'ctc_weight_' + sub) > 0:
                    tasks = ['ys_' + sub + '.ctc'] + tasks
    else:
        tasks = ['all']

    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    accum_n_steps = 0
    n_steps = optimizer.n_steps * args.accum_grad_n_steps
    epoch_detail_prev = 0
    for ep in range(resume_epoch, args.n_epochs):
        pbar_epoch = tqdm(total=len(train_set))
        session_prev = None

        for batch_train, is_new_epoch in train_set:
            # Compute loss in the training set
            if args.discourse_aware and batch_train['sessions'][0] != session_prev:
                model.module.reset_session()
            session_prev = batch_train['sessions'][0]
            accum_n_steps += 1

            # Change mini-batch depending on task
            if accum_n_steps == 1:
                loss_train = 0  # moving average over gradient accumulation
            for task in tasks:
                loss, observation = model(batch_train, task,
                                          teacher=teacher, teacher_lm=teacher_lm)
                reporter.add(observation)
                if use_apex:
                    with amp.scale_loss(loss, optimizer.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                loss = loss.detach()  # Truncate the graph
                loss_train = (loss_train * (accum_n_steps - 1) + loss.item()) / accum_n_steps
                if accum_n_steps >= args.accum_grad_n_steps or is_new_epoch:
                    if args.clip_grad_norm > 0:
                        total_norm = torch.nn.utils.clip_grad_norm_(
                            model.module.parameters(), args.clip_grad_norm)
                        reporter.add_tensorboard_scalar('total_norm', total_norm)
                    optimizer.step()
                    optimizer.zero_grad()
                    accum_n_steps = 0
                    # NOTE: parameters are forcibly updated at the end of every epoch
                del loss

            pbar_epoch.update(len(batch_train['utt_ids']))
            reporter.add_tensorboard_scalar('learning_rate', optimizer.lr)
            # NOTE: loss/acc/ppl are already added in the model
            reporter.step()
            n_steps += 1
            # NOTE: n_steps is different from the step counter in Noam Optimizer

            if n_steps % args.print_step == 0:
                # Compute loss in the dev set
                batch_dev = iter(dev_set).next(batch_size=1 if 'transducer' in args.dec_type else None)[0]
                # Change mini-batch depending on task
                for task in tasks:
                    loss, observation = model(batch_dev, task, is_eval=True)
                    reporter.add(observation, is_eval=True)
                    loss_dev = loss.item()
                    del loss
                reporter.step(is_eval=True)

                duration_step = time.time() - start_time_step
                if args.input_type == 'speech':
                    xlen = max(len(x) for x in batch_train['xs'])
                    ylen = max(len(y) for y in batch_train['ys'])
                elif args.input_type == 'text':
                    xlen = max(len(x) for x in batch_train['ys'])
                    ylen = max(len(y) for y in batch_train['ys_sub1'])
                logger.info("step:%d(ep:%.2f) loss:%.3f(%.3f)/lr:%.7f/bs:%d/xlen:%d/ylen:%d (%.2f min)" %
                            (n_steps, optimizer.n_epochs + train_set.epoch_detail,
                             loss_train, loss_dev,
                             optimizer.lr, len(batch_train['utt_ids']),
                             xlen, ylen, duration_step / 60))
                start_time_step = time.time()

            # Save figures of loss and accuracy
            if n_steps % (args.print_step * 10) == 0:
                reporter.snapshot()
                model.module.plot_attention()
                model.module.plot_ctc()

            # Evaluate the model every 0.1 epoch during MBR training
            if args.mbr_training:
                if int(train_set.epoch_detail * 10) != int(epoch_detail_prev * 10):
                    # dev
                    evaluate([model.module], dev_set, recog_params, args,
                             int(train_set.epoch_detail * 10) / 10, logger)
                    # Save the model
                    optimizer.save_checkpoint(
                        model, save_path, remove_old=False, amp=amp,
                        epoch_detail=train_set.epoch_detail)
                epoch_detail_prev = train_set.epoch_detail

            if is_new_epoch:
                break

        # Save checkpoint and evaluate model per epoch
        duration_epoch = time.time() - start_time_epoch
        logger.info('========== EPOCH:%d (%.2f min) ==========' %
                    (optimizer.n_epochs + 1, duration_epoch / 60))

        if optimizer.n_epochs + 1 < args.eval_start_epoch:
            optimizer.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save the model
            optimizer.save_checkpoint(
                model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)
        else:
            start_time_eval = time.time()
            # dev
            metric_dev = evaluate([model.module], dev_set, recog_params, args,
                                  optimizer.n_epochs + 1, logger)
            optimizer.epoch(metric_dev)  # lr decay
            reporter.epoch(metric_dev, name=args.metric)  # plot

            if optimizer.is_topk or is_transformer:
                # Save the model
                optimizer.save_checkpoint(
                    model, save_path, remove_old=not is_transformer and args.remove_old_checkpoints, amp=amp)

                # test
                if optimizer.is_topk:
                    for eval_set in eval_sets:
                        evaluate([model.module], eval_set, recog_params, args,
                                 optimizer.n_epochs, logger)

            duration_eval = time.time() - start_time_eval
            logger.info('Evaluation time: %.2f min' % (duration_eval / 60))

            # Early stopping
            if optimizer.is_early_stop:
                break

            # Convert to fine-tuning stage
            if optimizer.n_epochs == args.convert_to_sgd_epoch:
                optimizer.convert_to_sgd(model, args.lr, args.weight_decay,
                                         decay_type='always', decay_rate=0.5)

            if optimizer.n_epochs >= args.n_epochs:
                break
            # if args.ss_prob > 0:
            #     model.module.scheduled_sampling_trigger()

            start_time_step = time.time()
            start_time_epoch = time.time()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.2f hour' % (duration_train / 3600))

    reporter.tf_writer.close()
    pbar_epoch.close()

    return save_path
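The inner loop above is the usual gradient-accumulation pattern with apex amp: backpropagate a scaled loss for every mini-batch, but step the optimizer only every accum_grad_n_steps batches. A stripped-down sketch of that pattern follows; the toy model, data and opt_level are illustrative assumptions, and the original additionally wraps the optimizer in an LRScheduler and keeps a moving average of the loss for logging.

import torch
from apex import amp

model = torch.nn.Linear(10, 3).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loader = [(torch.randn(8, 10).cuda(), torch.randint(0, 3, (8,)).cuda()) for _ in range(8)]

accum_grad_n_steps = 4
accum = 0
for x, y in loader:
    loss = torch.nn.functional.cross_entropy(model(x), y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                     # gradients keep accumulating across mini-batches
    accum += 1
    if accum >= accum_grad_n_steps:                # parameters are updated only here
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=5.0)
        optimizer.step()
        optimizer.zero_grad()
        accum = 0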
Example #19
def main():
    cfg = Config()
    args = parse()
    local_rank = args.local_rank

    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = torch.device('cuda:{}'.format(local_rank))

    net = CSP(cfg).to(device)
    center = loss_cls().to(device)
    height = loss_reg().to(device)
    offset = loss_offset().to(device)

    optimizer = optim.Adam(net.parameters(), lr=cfg.init_lr)
    amp.register_float_function(torch, 'sigmoid')
    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    if args.resume:

        def resume():
            if os.path.isfile(args.resume):
                checkpoint = torch.load(args.resume, map_location='cpu')
                args.start_epoch = checkpoint['epoch']
                net.load_state_dict(checkpoint['model'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                amp.load_state_dict(checkpoint['amp'])
                if local_rank == 0:
                    print("=>loading checkpoint'{}'".format(args.resume))
                    print("=>loaded checkpoint '{}'(epoch {})".format(
                        args.resume, checkpoint['epoch']))
            else:
                print("=>no checkpoint found at '{}'".format(args.resume))

        resume()
    else:
        args.start_epoch = 0

    if cfg.teacher:
        teacher_dict = net.state_dict()
    else:
        teacher_dict = None

    net = DDP(net)

    # dataset
    gpus = eval(os.environ['CUDA_VISIBLE_DEVICES'])
    if isinstance(gpus, int):
        num_gpus = 1
    else:
        num_gpus = len(gpus)
    batchsize = cfg.onegpu
    args.epoch_length = int(cfg.iter_per_epoch / (num_gpus * batchsize))
    traindataset = CityPersons(path=cfg.root_path, type='train', config=cfg)
    datasampler = DistributedSampler(dataset=traindataset)
    trainloader = DataLoader(traindataset,
                             sampler=datasampler,
                             batch_size=batchsize,
                             shuffle=False,
                             num_workers=8)

    if cfg.val and local_rank == 0:
        testdataset = CityPersons(path=cfg.root_path, type='val', config=cfg)
        testloader = DataLoader(testdataset, batch_size=1, num_workers=4)
    cfg.ckpt_path = args.work_dir
    cfg.gpu_nums = num_gpus
    if local_rank == 0:
        cfg.print_conf()
        print('Training start')
        if not os.path.exists(cfg.ckpt_path):
            os.mkdir(cfg.ckpt_path)
        # open log file
        time_date = datetime.datetime.now()
        time_log = '{}{}{}_{}{}'.format(time_date.year, time_date.month,
                                        time_date.day, time_date.hour,
                                        time_date.minute)
        log_file = os.path.join(cfg.ckpt_path, time_log + '.log')
        log = open(log_file, 'w')
        cfg.write_conf(log)

    if cfg.add_epoch != 0:
        cfg.num_epochs = args.start_epoch + cfg.add_epoch

    args.iter_num = args.epoch_length * cfg.num_epochs

    args.best_loss = np.Inf
    args.best_loss_epoch = 0
    args.best_mr = 100
    args.best_mr_epoch = 0

    if args.resume and cfg.add_epoch == 0:
        args.iter_cur = args.start_epoch * args.epoch_length
    else:
        args.iter_cur = 0

    for epoch in range(args.start_epoch, cfg.num_epochs):
        datasampler.set_epoch(epoch)
        if local_rank == 0:
            print('----------')
            print('Epoch %d begin' % ((epoch + 1)))
        epoch_loss = train(trainloader,
                           net,
                           criterion,
                           center,
                           height,
                           offset,
                           optimizer,
                           epoch,
                           cfg,
                           args,
                           local_rank,
                           teacher_dict=teacher_dict)
        if local_rank == 0:
            if cfg.val and (epoch + 1) >= cfg.val_begin and (
                    epoch + 1) % cfg.val_frequency == 0:
                cur_mr = val(testloader,
                             net,
                             cfg,
                             args,
                             teacher_dict=teacher_dict)
                if cur_mr[0] < args.best_mr:
                    args.best_mr = cur_mr[0]
                    args.best_mr_epoch = epoch + 1
                    print('Epoch %d has lowest MR: %.7f' %
                          (args.best_mr_epoch, args.best_mr))
                    log.write(
                        'epoch_num: %d loss: %.7f Summarize: [Reasonable: %.2f%%], [Reasonable_small: %.2f%%], [Reasonable_occ=heavy: %.2f%%], [All: %.2f%%], lr: %.6f\n'
                        % (epoch + 1, epoch_loss, cur_mr[0] * 100, cur_mr[1] *
                           100, cur_mr[2] * 100, cur_mr[3] * 100, args.lr))
                else:
                    print('Epoch %d has lowest MR: %.7f' %
                          (args.best_mr_epoch, args.best_mr))
                    log.write(
                        'epoch_num: %d loss: %.7f Summarize: [Reasonable: %.2f%%], [Reasonable_small: %.2f%%], [Reasonable_occ=heavy: %.2f%%], [All: %.2f%%], lr: %.6f\n'
                        % (epoch + 1, epoch_loss, cur_mr[0] * 100, cur_mr[1] *
                           100, cur_mr[2] * 100, cur_mr[3] * 100, args.lr))
            if epoch + 1 >= cfg.val_begin - 1:
                print('Save checkpoint...')
                filename = cfg.ckpt_path + '/%s-%d.pth' % (
                    net.module.__class__.__name__, epoch + 1)
                checkpoint = {
                    'epoch': epoch + 1,
                    'optimizer': optimizer.state_dict(),
                    'amp': amp.state_dict()
                }
                if cfg.teacher:
                    checkpoint['model'] = teacher_dict
                else:
                    checkpoint['model'] = net.module.state_dict()
                torch.save(checkpoint, filename)
                # if cfg.teacher:
                #     torch.save(teacher_dict, filename+'.tea')
                print('%s saved.' % filename)
    if local_rank == 0:
        log.write('Epoch %d has lowest MR: %.7f' %
                  (args.best_mr_epoch, args.best_mr))
        log.close()
        print('End of training!')
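The checkpoints above also store the amp state via amp.state_dict(), and amp.load_state_dict restores it on resume, so the dynamic loss scale carries over across restarts instead of re-warming. The essentials of that save/resume pattern are sketched below; the toy model, optimizer and the file name 'ckpt.pth' are illustrative assumptions.

import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

# save
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict()}, 'ckpt.pth')

# resume: call amp.initialize first (as above), then restore all three state dicts
checkpoint = torch.load('ckpt.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])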
Example #20
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(
            idim_list[0] if args.num_encs == 1 else idim_list, odim, args
        )
    assert isinstance(model, ASRInterface)

    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
        )
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup."
            )

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(), args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0008, alpha=0.95)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype
        )

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack: set the batch size argument to 1
    # the actual batch size is encoded inside each element of the list
    # the default collate function converts numpy arrays to pytorch tensors
    # we use a custom collate function instead, which returns the list as-is
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0 and "transformer" in args.model_module:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if args.num_encs > 1:
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ]
            + ([] if args.num_encs == 1 else report_keys_loss_ctc),
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"]
            + ([] if args.num_encs == 1 else report_keys_loss_ctc),
            "epoch",
            file_name="cer.png",
        )
    )

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # lr decay in rmsprop
    if args.opt == "rmsprop":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "main/cer_ctc",
        "validation/main/cer_ctc",
        "elapsed_time",
    ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.opt == "rmsprop":
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
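register_float_function is not restricted to functions in the torch namespace: as in the snippet above, it can target a method of a user class so that amp casts that method's inputs to fp32 before it runs. A minimal sketch of that usage follows; MyCTC is a made-up stand-in for espnet's CTC module.

import torch
from apex import amp

class MyCTC(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ctc = torch.nn.CTCLoss(zero_infinity=True)

    def loss_fn(self, log_probs, targets, input_lengths, target_lengths):
        # the registration below makes amp cast the inputs of this method to fp32
        return self.ctc(log_probs, targets, input_lengths, target_lengths)

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        return self.loss_fn(log_probs, targets, input_lengths, target_lengths)

# register on the class, before amp.initialize()
amp.register_float_function(MyCTC, 'loss_fn')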
Example #21
import argparse
import multiprocessing as mp

import torch
import torch.distributed as dist
import torch.utils.data.distributed
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

#from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import Reducer
from apex.fp16_utils import *
from apex import amp

from model import Glow

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# likely because the invertible 1x1 convolutions rely on matrix inversion, which is
# numerically fragile in fp16, torch.inverse is forced to run in fp32 under amp
amp.register_float_function(torch, "inverse")
amp_handle = amp.init(enabled=True)


parser = argparse.ArgumentParser(description='Glow trainer')
parser.add_argument('--batch', default=6, type=int, help='batch size')
parser.add_argument('--iter', default=20000000, type=int, help='maximum iterations')
parser.add_argument(
    '--n_flow', default=32, type=int, help='number of flows in each block'
)
parser.add_argument('--n_block', default=7, type=int, help='number of blocks')
parser.add_argument(
    '--no_lu',
    action='store_true',
    help='use plain convolution instead of LU decomposed version',
)
Example #22
    def _initModel(
        self,
        optimizerCreator: Optional[OptimizerCreator] = None,
        forceDevice: str = None,
        modelTrain: bool = False
    ) -> Tuple[Optional[torch.optim.Optimizer], torch.device]:
        """
        Initialization method that gathers common operations that are done at start of train, eval or inference phase.

        :param optimizerCreator: Optimizer creator that will create optimizer that should be used for training.
            Send None when training is not your case.
        :type optimizerCreator: Optional[OptimizerCreator]
        :param forceDevice: Name of device that should be forced to torch. Default is none which means that it uses cuda
            if can and cpu otherwise.
        :type forceDevice: str
        :param modelTrain: Flag that determines mode of model. True -> Train. False -> eval
        :type modelTrain: bool
        :return: Returns  optimizer (in case of provided creator) and device where the
            model is.
        :rtype: Tuple[Optional[torch.optim.Optimizer], torch.device]
        :raise ImportError: When Nvidia apex couldn't be imported and therefore fp16 precision can not be used.
                    Screams only when you ask for fp16 precision.
        :raise AttributeError: This model is in mixed precision and can be used only with cuda device.
        """

        if forceDevice is None:
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(forceDevice)

        self.model = self.model.to(device)

        # Let's create user selected optimizer
        optimizer = None if optimizerCreator is None else optimizerCreator.create(
            self.model)

        if device.type != "cuda" and (self._shouldActivateMixedPrecision
                                      or self._mixedPrecisionActivated):
            raise AttributeError(
                "This model is in mixed precision and can be used only with cuda device."
            )

        if device.type == "cuda" and self._shouldActivateMixedPrecision:
            try:
                from apex import amp
                amp.register_float_function(torch, 'sigmoid')
                if not self._mixedPrecisionActivated:
                    self.model, optimizer = amp.initialize(
                        self.model,
                        optimizer,
                        opt_level=self.FP16_MIX_PREC_OPT_LVL.value)
                    self._mixedPrecisionActivated = True
                    self._shouldActivateMixedPrecision = False
            except ImportError:
                raise ImportError(
                    "You must install apex (https://www.github.com/nvidia/apex) for fp16 training."
                )

        if modelTrain:
            self.model.train()
        else:
            self.model.eval()

        return optimizer, device
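
For reference, here is a minimal, self-contained sketch of the apex amp pattern that _initModel sets up (register a float-only op, patch model and optimizer), together with the amp.scale_loss call a training loop would pair with it. The toy model, data, and loss below are placeholders and are not part of the example above:

import torch
import torch.nn.functional as F
from apex import amp

# Placeholder model/optimizer standing in for the ones _initModel prepares.
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Keep sigmoid in fp32 under mixed precision, then patch model/optimizer for O1.
amp.register_float_function(torch, 'sigmoid')
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

inputs = torch.randn(8, 16).cuda()
targets = torch.rand(8, 1).cuda()
loss = F.mse_loss(torch.sigmoid(model(inputs)), targets)

# Route backward through amp so the loss scaler can manage fp16 gradients.
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
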
Example #23
def train(gpu, config, shared_dict, barrier, train_ds, val_ds, backbone):
    # --- Set seeds --- #
    # For DistributedDataParallel: make sure all models are initialized identically
    torch.manual_seed(2)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # os.environ['CUDA_LAUNCH_BLOCKING'] = 1
    torch.autograd.set_detect_anomaly(True)

    # --- Setup DistributedDataParallel --- #
    rank = config["nr"] * config["gpus"] + gpu
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://',
                                         world_size=config["world_size"],
                                         rank=rank)

    if gpu == 0:
        print("# --- Start training --- #")

    # --- Setup run --- #
    # Setup run on process 0:
    if gpu == 0:
        shared_dict["run_dirpath"], shared_dict[
            "init_checkpoints_dirpath"] = local_utils.setup_run(config)
    barrier.wait()  # Wait on all processes so that shared_dict is synchronized.

    # Choose device
    torch.cuda.set_device(gpu)

    # --- Online transform performed on the device (GPU):
    train_online_cuda_transform = data_transforms.get_online_cuda_transform(
        config, augmentations=config["data_aug_params"]["enable"])
    if val_ds is not None:
        eval_online_cuda_transform = data_transforms.get_online_cuda_transform(
            config, augmentations=False)
    else:
        eval_online_cuda_transform = None

    if "samples" in config:
        rng_samples = random.Random(0)
        train_ds = torch.utils.data.Subset(
            train_ds,
            rng_samples.sample(range(len(train_ds)), config["samples"]))
        if val_ds is not None:
            val_ds = torch.utils.data.Subset(
                val_ds,
                rng_samples.sample(range(len(val_ds)), config["samples"]))
        # test_ds = torch.utils.data.Subset(test_ds, list(range(config["samples"])))

    if gpu == 0:
        print(f"Train dataset has {len(train_ds)} samples.")

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_ds, num_replicas=config["world_size"], rank=rank)
    val_sampler = None
    if val_ds is not None:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_ds, num_replicas=config["world_size"], rank=rank)
    if "samples" in config:
        eval_batch_size = min(2 * config["optim_params"]["batch_size"],
                              config["samples"])
    else:
        eval_batch_size = 2 * config["optim_params"]["batch_size"]

    init_dl = torch.utils.data.DataLoader(train_ds,
                                          batch_size=eval_batch_size,
                                          pin_memory=True,
                                          sampler=train_sampler,
                                          num_workers=config["num_workers"],
                                          drop_last=True)
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size=config["optim_params"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        sampler=train_sampler,
        num_workers=config["num_workers"],
        drop_last=True)
    if val_ds is not None:
        val_dl = torch.utils.data.DataLoader(val_ds,
                                             batch_size=eval_batch_size,
                                             pin_memory=True,
                                             sampler=val_sampler,
                                             num_workers=config["num_workers"],
                                             drop_last=True)
    else:
        val_dl = None

    model = FrameFieldModel(config,
                            backbone=backbone,
                            train_transform=train_online_cuda_transform,
                            eval_transform=eval_online_cuda_transform)
    model.cuda(gpu)
    if gpu == 0:
        print("Model has {} trainable params".format(
            count_trainable_params(model)))

    loss_func = losses.build_combined_loss(config).cuda(gpu)
    # Compute learning rate
    lr = min(
        config["optim_params"]["base_lr"] *
        config["optim_params"]["batch_size"] * config["world_size"],
        config["optim_params"]["max_lr"])

    if config["optim_params"]["optimizer"] == "Adam":
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=lr,
            # weight_decay=config["optim_params"]["weight_decay"],
            eps=1e-8  # Increase if instability is detected
        )
    elif config["optim_params"]["optimizer"] == "RMSProp":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    else:
        raise NotImplementedError(
            f"Optimizer {config['optim_params']['optimizer']} not recognized")
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    if config["use_amp"] and APEX_AVAILABLE:
        amp.register_float_function(torch, 'sigmoid')
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    elif config["use_amp"] and not APEX_AVAILABLE and gpu == 0:
        print_utils.print_warning(
            "WARNING: Cannot use amp because the apex library is not available!"
        )

    # Wrap the model for distributed training
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=True)

    # def lr_warmup_func(epoch):
    #     if epoch < config["warmup_epochs"]:
    #         coef = 1 + (config["warmup_factor"] - 1) * (config["warmup_epochs"] - epoch) / config["warmup_epochs"]
    #     else:
    #         coef = 1
    #     return coef
    # lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_warmup_func)
    # lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, config["optim_params"]["gamma"])

    trainer = Trainer(
        rank,
        gpu,
        config,
        model,
        optimizer,
        loss_func,
        run_dirpath=shared_dict["run_dirpath"],
        init_checkpoints_dirpath=shared_dict["init_checkpoints_dirpath"],
        lr_scheduler=lr_scheduler)
    trainer.fit(train_dl, val_dl=val_dl, init_dl=init_dl)
Example #24
    def __init__(self, config):
        self.config = config
        pprint.pprint(config.__dict__)

        if config.backbone == 'resnet18':
            self.model = resnet_face18()
        elif config.backbone == 'resnet34':
            self.model = resnet_face34()
        elif config.backbone == 'resnet50':
            self.model = se_resnet50_ir()

        if config.pretrained:
            checkpoint = torch.load(config.pretrained, map_location="cpu")
            self.model.load_state_dict(checkpoint)
            print("Load pretrained model!")

        self.train_counter = 0
        if not os.path.exists(self.config.log_dir):
            os.mkdir(self.config.log_dir)

        if not os.path.exists(self.config.weight_dir):
            os.mkdir(self.config.weight_dir)

        logging.basicConfig(level=logging.DEBUG,  # log level printed to the console
                            filename=os.path.join(self.config.log_dir, 'Trainer_{}.log'.format(
                                datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
                            )),
                            filemode='w',
                            format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
                            )

        self.train_dh = ImageFolder(config.train_datasets,
                                    transform=trans,
                                    target_transform=target_trans,
                                    is_valid_file=lambda x: x.endswith(('.jpg', '.png', '.jpeg'))
                                    )
        # print("ImageFolder.__class__", self.train_dh.class_to_idx)
        assert len(self.train_dh.classes) == config.num_classes
        self.train_loader = Data.DataLoader(
            self.train_dh,
            batch_size=self.config.batch_size,
            num_workers=config.num_workers,
            pin_memory=True,
            shuffle=True
        )
        print("Load datasets!")
        self.lfw_list = get_lfw_list(config.lfw_test_list)
        self.lfw_paths = [os.path.join(config.lfw_root, each) for each in self.lfw_list]

        self.logger = logging.getLogger(__name__)
        self.writer = SummaryWriter(logdir=self.config.log_dir)

        if config.loss == 'focal_loss':
            self.criterion = FocalLoss(gamma=2)
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

        if config.metric == 'add_margin':
            self.metric_fc = AddMarginProduct(512, config.num_classes, s=30, m=0.35)
        elif config.metric == 'arc_margin':
            self.metric_fc = ArcMarginProduct(512, config.num_classes, s=64, m=0.5, easy_margin=config.easy_margin, fp16=config.fp16)
        elif config.metric == 'sphere':
            self.metric_fc = SphereProduct(512, config.num_classes, m=4)
        else:
            self.metric_fc = nn.Linear(512, config.num_classes)

        if config.optimizer == 'sgd':
            # self.optimizer = torch.optim.SGD(
            #     [{'params': self.model.parameters()}, {'params': self.metric_fc.parameters()}],
            #     lr=config.lr, weight_decay=config.weight_decay)
            self.optimizer = torch.optim.SGD(
                self.add_weight_decay(self.model) + self.add_weight_decay(self.metric_fc),
                lr=config.lr, momentum=config.momentum)
        else:
            self.optimizer = torch.optim.Adam(
                [{'params': self.model.parameters()}, {'params': self.metric_fc.parameters()}],
                lr=config.lr, weight_decay=config.weight_decay)

        # lr decay

        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=config.lr_step, gamma=0.1)
        self.scheduler_warmup = GradualWarmupScheduler(
            self.optimizer,
            multiplier=1,
            total_epoch=3,
            after_scheduler=self.scheduler)

        self.cuda = config.cuda and torch.cuda.is_available()

        if self.cuda:
            self.criterion.cuda()
            self.model.cuda()
            self.metric_fc.cuda()
        if not self.config.fp16 and config.num_gpu > 1:
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.metric_fc = torch.nn.DataParallel(self.metric_fc).cuda()

        if self.config.fp16:
            amp.register_float_function(utils, 'ArcMarginProduct')
            [self.model, self.metric_fc], self.optimizer = amp.initialize([self.model, self.metric_fc],
                                                                          self.optimizer,
                                                                          opt_level=config.opt_level
                                                                          )
            if config.num_gpu > 1:
                self.model = DDP(self.model)
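
In the fp16 branch above, both the backbone and the metric head are patched against a single optimizer, so a training step has to run its backward pass through amp.scale_loss. A minimal sketch of such a step, assuming a margin-based head that takes (features, labels) and the attributes created in __init__; the train_step method and its argument names are hypothetical, not part of the original class:

    def train_step(self, images, labels):
        # Hypothetical helper: one fp16-aware parameter update.
        features = self.model(images)                # backbone embeddings
        logits = self.metric_fc(features, labels)    # e.g. ArcMarginProduct(features, labels)
        loss = self.criterion(logits, labels)

        self.optimizer.zero_grad()
        if self.config.fp16:
            # The models were patched by amp.initialize, so scale the loss before backward.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.optimizer.step()
        return loss.item()
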
Example #25
def train():
    """
    """

    args = parser.parse_args()

    if args.start_fold == 0:
        start_fold = 1
    else:
        start_fold = args.start_fold

    for k in range(start_fold, args.fold + 1):
        print("========== Fold {} ==========".format(k))
        # ---- setting model & optimizer ----
        model = selec_model(args.arch)

        if args.optm == "SGD":
            optimizer = optim.SGD(model.parameters(),
                                  lr=args.lr,
                                  momentum=0.9,
                                  weight_decay=0.0005)
            # optimizer = optim.SGD(model.parameters(), lr=args.lr)
        elif args.optm == "Adam":
            optimizer = optim.Adam(model.parameters(),
                                   lr=args.lr,
                                   betas=(0.5, 0.999))
        scheduler = ReduceLROnPlateau(optimizer,
                                      'min',
                                      factor=0.5,
                                      patience=100,
                                      verbose=True,
                                      eps=1e-6)

        if args.apex:
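            # Keep sigmoid/softmax in fp32 under mixed precision to avoid fp16 overflow/underflow.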
            amp.register_float_function(torch, 'sigmoid')
            amp.register_float_function(F, 'softmax')
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        model = torch.nn.DataParallel(model)
        if args.start_fold != 0:
            model_ckpt = torch.load(r"./checkpoints/fold_{}.pth.tar".format(k))
            model.load_state_dict(model_ckpt)
        model.train()

        # ---- Tensorboard ----
        writer = SummaryWriter()

        # ---- start training ----
        path_txt_train = r"{}/fold_{}/train.txt".format(args.txt, k)
        loss_min = 100
        n_plateau = 0
        for epoch in range(1, args.epoch + 1):
            # ---- timer ----
            starttime = datetime.datetime.now()
            # ---- loading data ----
            print("Loading data...")
            dataset = ImageNetDataset(args.data, path_txt_train, "train")
            data = DataLoader(dataset=dataset,
                              batch_size=args.bs,
                              shuffle=True,
                              num_workers=12,
                              pin_memory=True)
            # ---- loop for all train data ----
            loss_train_sum = 0
            n_sum = 0
            n_error = 0
            for step, (inputs, labels) in enumerate(tqdm(data)):
                # ---- inputs & labels ----
                inputs = inputs.to(device, dtype=torch.float)
                labels = labels.to(device,
                                   dtype=torch.long)  # [batch, 1000] & long
                # print("Input shape:{}".format(inputs.shape))
                # print("Label shape:{}".format(labels.shape))
                # ---- fp ----
                preds = model(inputs)  # [batch, 1000]
                labels_max = torch.max(labels, 1)[1]
                loss = criterion(preds, labels_max)
                loss_train_sum += loss.item()
                if args.apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # ---- top 5 error ----
                n_sum_i, n_error_i = cal_top5_error(preds, labels)
                n_sum += n_sum_i
                n_error += n_error_i
            # ---- train loss ----
            loss_train = loss_train_sum / len(data)
            # ---- train error ----
            top5_error_train = round(n_error / n_sum, 3)
            # ---- validation ----
            loss_val, top5_error_val = validation(model, k, args)
            scheduler.step(loss_val)

            # ---- saving ckpt: minimum val loss ----
            if loss_val < loss_min:
                loss_min = loss_val
                print("Best model saved at epoch {}!".format(epoch))
                torch.save(model.state_dict(),
                           r"./checkpoints/fold_{}.pth.tar".format(k))
                n_plateau = 0
            else:
                n_plateau += 1
            # ---- timer ----
            endtime = datetime.datetime.now()
            elapsed = (endtime - starttime).seconds
            # ---- printing ----
            print(
                "fold #{}, epoch #{}, train: {:.4f}, val: {:.4f}, top5_train:{}, top5_val:{}, elapsed: {}s"
                .format(k, epoch, loss_train, loss_val, top5_error_train,
                        top5_error_val, elapsed))
            print('-' * 60)
            if n_plateau >= 20:
                break
Example #26
def setApex(model_g, model_d, optimG, optimD):
    amp.register_float_function(torch, 'sigmoid')
    amp.register_float_function(F, 'softmax')
    model_g, optimG = amp.initialize(model_g, optimG, opt_level="O1")
    model_d, optimD = amp.initialize(model_d, optimD, opt_level="O1")
    return model_g, model_d, optimG, optimD
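
A usage sketch for setApex in a GAN loop: each loss is backpropagated through amp.scale_loss with its own optimizer. The g_loss and d_loss variables below are placeholders for the real generator and discriminator objectives. Note that calling amp.initialize twice, as above, works, but the apex documentation suggests a single call with lists of models and optimizers when possible.

model_g, model_d, optimG, optimD = setApex(model_g, model_d, optimG, optimD)

# Discriminator update (d_loss is a placeholder).
optimD.zero_grad()
with amp.scale_loss(d_loss, optimD) as scaled_d_loss:
    scaled_d_loss.backward()
optimD.step()

# Generator update (g_loss is a placeholder).
optimG.zero_grad()
with amp.scale_loss(g_loss, optimG) as scaled_g_loss:
    scaled_g_loss.backward()
optimG.step()
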
Example #27
def main():
    args = parse_args()
    global local_rank
    local_rank = args.local_rank
    if local_rank == 0:
        global logger
        logger = get_logger(__name__, args.log)

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    global gpus_num
    gpus_num = torch.cuda.device_count()
    if local_rank == 0:
        logger.info(f'use {gpus_num} gpus')
        logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    if local_rank == 0:
        logger.info('start loading data')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        Config.train_dataset, shuffle=True, rank=local_rank)
    
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.per_node_batch_size,
                              shuffle=False,
                              pin_memory=True,
                              drop_last=True,
                              num_workers=args.num_workers,
                              sampler=train_sampler)
    if local_rank == 0:
        logger.info('finish loading data')

    model = centernet.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
        "multi_head": args.multi_head,
        "selayer": args.selayer,
        "use_ttf": args.use_ttf,
        "cls_mlp": args.cls_mlp
    })
    
    if args.multi_head:
        pre_model = torch.load(args.pre_model_dir, map_location='cpu')
        if local_rank == 0:
            logger.info(f"pretrained_model: {args.pre_model_dir}")
        
        if args.load_head:
            def copyStateDict(state_dict):
                if list(state_dict.keys())[0].startswith('module'):
                    start_idx = 1
                else:
                    start_idx = 0
                new_state_dict = OrderedDict()
                for k,v in state_dict.items():
                    name = '.'.join(k.split('.')[start_idx:])

                    new_state_dict[name] = v
                return new_state_dict
            new_dict=copyStateDict(pre_model)

            keys=[]
            keys2 = []
            for k,v in new_dict.items():
                keys.append(k)
#                 if k.startswith('centernet_head.heatmap_head.0'):
#                     continue
#                 else:
#                     keys2.append(k)
                keys2.append(k)
            final_dict = {k: new_dict[k] for k in keys}
            # Duplicate the first head's weights into the second head.
            for item in keys2:
                temp_name = copy.deepcopy(item)
                final_dict[temp_name.replace('centernet_head', 'centernet_head_2')] = new_dict[item]

            model.load_state_dict(final_dict, strict=False)
            
        else:
            model.load_state_dict(pre_model, strict=False)
        
        for p in model.backbone.parameters():
            p.requires_grad = False
        for p in model.centernet_head.parameters():
            p.requires_grad = False
            
    
    if local_rank == 0:
        for name, param in model.named_parameters():
            logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    if local_rank == 0:
        logger.info(
            f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = CenterNetLoss(max_object_num=Config.max_object_num).cuda()
    decoder = CenterNetDecoder(image_w=args.input_image_size,
                               image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.1)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
#                                                            patience=3,
#                                                            verbose=True)

    if args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = apex.parallel.DistributedDataParallel(model,
                                                      delay_allreduce=True)
        if args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
    else:
        model = nn.parallel.DistributedDataParallel(model,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            if local_rank == 0:
                logger.exception(
                    '{} is not a file, please check it again'.format(
                        args.evaluate))
            sys.exit(-1)
        if local_rank == 0:
            logger.info('start only evaluating')
            logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if local_rank == 0:
            logger.info(f"start eval.")
            all_eval_result = validate(Config.val_dataset, model, decoder,
                                       args)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        if local_rank == 0:
            logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        if local_rank == 0:
            logger.info(
                f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
                f"loss: {checkpoint['loss']:3f}, heatmap_loss: {checkpoint['heatmap_loss']:2f}, offset_loss: {checkpoint['offset_loss']:2f},wh_loss: {checkpoint['wh_loss']:2f}"
            )

    if local_rank == 0:
        if not os.path.exists(args.checkpoints):
            os.makedirs(args.checkpoints)

    if local_rank == 0:
        logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)
        heatmap_losses, offset_losses, wh_losses, losses = train(
            train_loader, model, criterion, optimizer, scheduler, epoch, args)

        if local_rank == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, heatmap_loss: {heatmap_losses:.2f}, offset_loss: {offset_losses:.2f}, wh_loss: {wh_losses:.2f}, loss: {losses:.2f}"
            )

        if epoch % 10 == 0 or epoch == args.epochs:
            if local_rank == 0:
                logger.info(f"start eval.")
                all_eval_result = validate(Config.val_dataset, model, decoder,
                                           args)
                logger.info(f"eval done.")
                if all_eval_result is not None:
                    logger.info(
                        f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                    )
                    if all_eval_result[0] > best_map:
                        torch.save(model.module.state_dict(),
                                   os.path.join(args.checkpoints, "best.pth"))
                        best_map = all_eval_result[0]
        if local_rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'best_map': best_map,
                    'heatmap_loss': heatmap_losses,
                    'offset_loss': offset_losses,
                    'wh_loss': wh_losses,
                    'loss': losses,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, os.path.join(args.checkpoints, 'latest.pth'))

    if local_rank == 0:
        logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    if local_rank == 0:
        logger.info(
            f"finish training, total training time: {training_time:.2f} hours")
Example #28
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional

try:
    from apex import amp

    amp.register_float_function(torch, 'matmul')
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex to run this example."
    )


class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """
    def __init__(self,
                 type='nsgan',
                 target_real_label=1.0,
                 target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__()

        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
Example #29
def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    collater = Collater()
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collater.next)
    logger.info('finish loading data')

    model = retinanet.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.resume} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info(f"start eval.")
        all_eval_result = validate(Config.val_dataset, model, decoder, args)
        logger.info(f"eval done.")
        if all_eval_result is not None:
            logger.info(
                f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
            )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
            f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
        )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
        )

        if epoch % 1 == 0 or epoch == args.epochs:
            logger.info(f"start eval.")
            all_eval_result = validate(Config.val_dataset, model, decoder,
                                       args)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )
                if all_eval_result[0] > best_map:
                    torch.save(model.module.state_dict(),
                               os.path.join(args.checkpoints, "best.pth"))
                    best_map = all_eval_result[0]
        torch.save(
            {
                'epoch': epoch,
                'best_map': best_map,
                'cls_loss': cls_losses,
                'reg_loss': reg_losses,
                'loss': losses,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))

    logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")
Example #30
    def train(self):
        self.best_epoch_raw_ssim_loss = -1e19

        had_training_error = False

        if self.args.use_amp:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this program.")
            print('enabling AMP...')
            opt_level = "O1"
            amp.register_float_function(torch, 'batch_norm')
            if self.args.use_gan:
                self.discriminator, self.disc_optimizer = amp.initialize(self.discriminator, self.disc_optimizer, opt_level=opt_level, loss_scale=1.0)
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=opt_level, loss_scale=1.0)
            # disable use of ssim loss for training due to issues with mixed precision
            self.args.w_gen_ssim = 0.0

        if self.args.use_data_parallel:
            print('enabling data parallel...')
            self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            self.loss_function = nn.DataParallel(self.loss_function)
            self.loss_function = self.loss_function.to(self.device)
            if self.args.use_gan:
                self.discriminator = nn.DataParallel(self.discriminator)
                self.discriminator = self.discriminator.to(self.device)
                self.gan_criterion = nn.DataParallel(self.gan_criterion)
                self.gan_criterion = self.gan_criterion.to(self.device)

                if 0.0 < self.args.w_gen_seg3d or self.args.use_seg3d_proxy:
                    self.seg_criterion = nn.DataParallel(self.seg_criterion)
                    self.seg_criterion = self.seg_criterion.to(self.device)

        for epoch in range(self.args.epochs):
            if had_training_error:
                print('exiting early due to training error')
                break

            self.adjust_learning_rate(self.optimizer, epoch,
                                      self.args.learning_rate)
            if self.args.use_gan:
                self.adjust_learning_rate(self.disc_optimizer, epoch,
                                          self.args.disc_learning_rate)

            start_time = time.time()
            loss_sum = 0.0
            disc_loss_sum = 0.0

            for (i, data) in enumerate(self.train_loader, 0):
                self.model.train()
                if self.args.use_gan and 0.0 < self.args.w_disc_gan_label:
                    self.discriminator.train()

                self.batch_num += 1

                if self.args.test_interval != 0 and self.batch_num % self.args.test_interval == 0:
                    self.run_test_batch(use_file_tuples=False)

                if (not self.args.use_variable_num_views) or (self.batch_num % self.args.log_interval == 0):
                    num_inputs_to_use = self.args.num_combine_views
                else:
                    sample_prob = np.random.uniform(0, 1, 1)[0]
                    if sample_prob < 0.500:
                        num_inputs_to_use = 1
                    elif sample_prob < 0.750:
                        num_inputs_to_use = 2
                    elif sample_prob < 0.875:
                        num_inputs_to_use = 3
                    else:
                        num_inputs_to_use = 4

                data = self.get_data(data, num_inputs_to_use)

                if 0 == (i + 1) % self.args.log_interval:
                    print('start ' + str(i + 1) + ' of ' + str(self.n_img / self.args.batch_size))

                model_out = self.model(num_inputs_to_use, data)

                if self.args.use_gan and 0.0 < self.args.w_disc_gan_label:
                    # reset training params
                    self.disc_optimizer.zero_grad()

                    disc_loss = self.compute_disc_losses(model_out, data, loss_type='train')

                    # update model
                    if self.args.use_amp:
                        with amp.scale_loss(disc_loss, self.disc_optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        disc_loss.backward()
                    self.disc_optimizer.step()

                # reset training params
                self.optimizer.zero_grad()

                loss = self.compute_gen_losses(model_out, data, loss_type='train')

                # update model
                if self.args.use_amp:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                self.optimizer.step()

                loss_item = loss.item()
                if self.args.use_gan and 0.0 < self.args.w_disc_gan_label:
                    disc_loss_item = disc_loss.item()
                else:
                    disc_loss_item = 0.0

                # check for NAN during training
                if loss_item != loss_item or disc_loss_item != disc_loss_item:
                    print('NAN loss in training:', loss_item, disc_loss_item)
                    had_training_error = True
                    exit(-1)

                loss_sum += loss_item

                if self.batch_num % self.args.log_interval == 0:
                    log_string = "Batch %d" % self.batch_num
                    for k, v in self.logs.items():
                        if 'l_eval_' == k[0:7]:
                            scale_factor = 1.0
                        elif 'l_test_' == k[0:7]:
                            scale_factor = (float(self.args.log_interval) / self.args.test_interval) if 0 != self.args.test_interval else 1.0
                        else:
                            scale_factor = float(self.args.log_interval)
                        log_string += " [%s] %5.3f" % (k, v / scale_factor)

                    log_string += ". Took %5.2f" % (time.time() - start_time)

                    print(log_string)

                    for tag, value in self.logs.items():
                        if 'l_eval_' == tag[0:7]:
                            scale_factor = 1.0
                        elif 'l_test_' == tag[0:7]:
                            scale_factor = (float(self.args.log_interval) / self.args.test_interval) if 0 != self.args.test_interval else 1.0
                        else:
                            scale_factor = float(self.args.log_interval)
                        self.logger.scalar_summary(tag, value / scale_factor, self.batch_num)

                    self.reset_logs('train')
                    self.reset_logs('test')
                    self.log_images(model_out, data, 'T_Images')

                if 0 == (i + 1) % self.args.log_interval:
                    crnt_time = time.time()
                    print('end ' + str(i + 1) + ' of ' + str(self.n_img / self.args.batch_size))
                    print(
                        'time:',
                        round(crnt_time - start_time, 3),
                        's',
                        loss_item,
                        loss_sum / self.args.log_interval,
                    )
                    start_time = crnt_time
                    loss_sum = 0.0

                # save regularly after processing the specified number of input images
                if 0 == self.batch_num % self.args.int_save_interval:
                    model_name = self.args.model_path[:-4] + '_int_cpt.pth'
                    self.save(model_name)

                if 0 == self.batch_num % self.args.checkpoint_save_interval:
                    model_name = self.args.model_path[:-4] + '_batch_' + str(int(self.batch_num / 1000.0)) + 'k_cpt.pth'
                    self.save(model_name)
                    print('Checkpoint model saved to ' + model_name)

            if 0 == (epoch + 1) % self.args.epoch_save_interval:
                model_name = self.args.model_path[:-4] + '_epoch_' + str(epoch) + '_cpt.pth'
                self.save(model_name)
                print('Epoch model saved to ' + model_name)

            if self.do_run_eval:
                epoch_test_loss, epoch_raw_ssim_loss = self.run_eval()
                print('Epoch testing loss: ' + str(epoch) + ' ' + str(epoch_test_loss) + ' ' + str(epoch_raw_ssim_loss))
                if epoch_raw_ssim_loss > self.best_epoch_raw_ssim_loss:
                    print('Best raw ssim: ' + str(1.0 - epoch_raw_ssim_loss) + ' ' + str(epoch_raw_ssim_loss))
                    self.best_epoch_raw_ssim_loss = epoch_raw_ssim_loss
                    model_name = self.args.model_path[:-4] + '_best.pth'
                    self.save(model_name)

        print('Finished Training. Best loss: ', self.best_epoch_raw_ssim_loss)
        self.save(self.args.model_path)