Example #1
class DistributedTrainer(LearnerCallback):
    _order = -20 # Needs to run before the recorder
    def __init__(self, learn:Learner, cuda_id:int=0):
        super().__init__(learn)
        self.cuda_id,self.train_sampler = cuda_id,None

    def on_train_begin(self, **kwargs):
        self.learn.model = DistributedDataParallel(self.learn.model, device_ids=[self.cuda_id],
                                                   output_device=self.cuda_id)
        self.train_sampler = DistributedSampler(self.learn.data.train_dl.dataset)
        self.learn.data.train_dl = self.learn.data.train_dl.new(shuffle=False, sampler=self.train_sampler)
        self.learn.data.train_dl.add_tfm(make_async)
        if hasattr(self.learn.data, 'valid_dl') and self.learn.data.valid_dl is not None:
            self.valid_sampler = DistributedSampler(self.learn.data.valid_dl.dataset)
            self.learn.data.valid_dl = self.learn.data.valid_dl.new(shuffle=False, sampler=self.valid_sampler)
            self.learn.data.valid_dl.add_tfm(make_async)
        self.rank = rank_distrib()
        self.learn.recorder.silent = (self.rank != 0)

    def on_epoch_begin(self, epoch, **kwargs): self.train_sampler.set_epoch(epoch)

    def on_train_end(self, **kwargs):
        self.learn.model = self.learn.model.module
        self.learn.data.train_dl.remove_tfm(make_async)
        if hasattr(self.learn.data, 'valid_dl') and self.learn.data.valid_dl is not None:
            self.learn.data.valid_dl.remove_tfm(make_async)
Example #2
def make_dataloader(cfg, distributed=False, mode="train"):
    if mode == "train":
        dataset = SaliencyDataset(cfg.dataset,
                                  prior=cfg.prior,
                                  transform=train_transform)
        if cfg.gpu_id == 0:
            print("=> [{}] Dataset: {} - Prior: {} | {} images.".format(
                mode.upper(), cfg.dataset, cfg.prior, len(dataset)))
    else:
        dataset = SaliencyDataset(cfg.test_dataset,
                                  prior=cfg.test_prior,
                                  transform=test_transform)
        if cfg.gpu_id == 0:
            print("=> [{}] Dataset: {} - Prior: {} | {} images.".format(
                mode.upper(), cfg.test_dataset, cfg.test_prior, len(dataset)))

    data_sampler = DistributedSampler(dataset) if distributed else None
    data_loader_ = data.DataLoader(
        dataset,
        batch_size=cfg.batch_size if mode == "train" else 1,
        num_workers=cfg.num_workers,
        sampler=data_sampler,
        collate_fn=SaliencyDataset.collate_fn,
    )
    if mode == "train":
        data_loader = Dataloader(data_loader_, distributed)
    else:
        data_loader = data_loader_
        if distributed:
            data_sampler.set_epoch(0)

    return data_loader
Example #3
class MultiprocessingDataloader(SqliteDataLoader):

    #------------------------------------
    # Constructor
    #-------------------

    def __init__(self, dataset, world_size, node_rank, **kwargs):

        self.dataset = dataset

        self.sampler = DistributedSampler(dataset,
                                          num_replicas=world_size,
                                          rank=node_rank)

        super().__init__(dataset,
                         shuffle=False,
                         num_workers=0,
                         pin_memory=True,
                         sampler=self.sampler,
                         **kwargs)

    #------------------------------------
    # set_epoch
    #-------------------

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)
Example #4
def train(model):
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train', loaddata=True)
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset('validation', loaddata=True)
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    evaluate(model, val_data, nr_eval)
    model.save_model(log_path, local_rank)
    model.load_model(log_path, local_rank)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
            learning_rate = get_learning_rate(step)
            pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, learning_rate, mul, True, flow_gt)
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            #if step % 100 == 1 and local_rank == 0:
            writer.add_scalar('learning_rate', learning_rate, step) 
            writer.add_scalar('loss_l1', loss_l1, step)
            writer.add_scalar('loss_flow', loss_flow, step)
            writer.add_scalar('loss_cons', loss_cons, step)
            writer.add_scalar('loss_ter', loss_ter, step)
            if step % 1000 == 1 and local_rank == 0:
                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                merged_img = (merged_img.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow = flow.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_mask = flow_mask.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_gt = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy()
                for j in range(5):  # use j so the outer batch counter i (used in the progress print below) is preserved
                    imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
                    writer.add_image(str(j) + '/img', imgs, step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow', flow2rgb(flow[j]), step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow_gt', flow2rgb(flow_gt[j]), step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow_mask', flow2rgb(flow[j] * flow_mask[j]), step, dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_l1))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step)
        model.save_model(log_path, local_rank)    
        dist.barrier()
Example #5
    def _get_sampler(self, epoch):
        world_size = get_world_size()
        rank = get_rank()
        sampler = DistributedSampler(self,
                                     num_replicas=world_size,
                                     rank=rank,
                                     shuffle=self.shuffle)
        sampler.set_epoch(epoch)
        return sampler
Example #6
def training(cfg, world_size, model, dataset_train, dataset_validation):
    rank = cfg.local_rank
    batch_size = cfg.batch_size
    epochs = cfg.epoch
    lr = cfg.learning_rate
    random_seed = cfg.random_seed

    sampler = DistributedSampler(dataset=dataset_train,
                                 num_replicas=world_size,
                                 rank=rank,
                                 shuffle=True,
                                 seed=random_seed,
                                 drop_last=True)

    train_loader = DataLoader(dataset=dataset_train,
                              batch_size=batch_size,
                              pin_memory=True,
                              sampler=sampler)

    val_loader = DataLoader(dataset=dataset_validation,
                            batch_size=batch_size,
                            pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    try:
        for epoch in range(epochs):
            # In distributed mode, calling the set_epoch() method at the beginning of each epoch
            # before creating the DataLoader iterator is necessary to make shuffling work properly
            # across multiple epochs. Otherwise, the same ordering will always be used.
            sampler.set_epoch(epoch)

            train_one_epoch(model=model,
                            train_loader=train_loader,
                            optimizer=opt,
                            criterion=criterion,
                            rank=rank,
                            world_size=world_size,
                            epoch=epoch,
                            num_epoch=epochs)

            validation(model=model,
                       data_loader=val_loader,
                       criterion=criterion,
                       rank=rank,
                       world_size=world_size,
                       epoch=epoch,
                       num_epoch=epochs)

    except KeyboardInterrupt:
        pass

    return model
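
The comment in the training loop above applies to every DistributedSampler on this page: unless set_epoch() is called, the shuffled ordering is reused verbatim. A minimal standalone sketch (not taken from any of the quoted projects; num_replicas and rank are passed explicitly so no process group is required) that makes this visible:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 16 items, sharded as if two replicas were running.
dataset = TensorDataset(torch.arange(16))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

print(list(sampler))   # ordering for epoch 0
print(list(sampler))   # identical: the epoch was never advanced
sampler.set_epoch(1)
print(list(sampler))   # a different ordering for epoch 1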
Example #7
def test_dataset_loader():
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset
    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)
    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')
    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch Loader
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset,
                              batch_size=hps.bs,
                              num_workers=hps.nworkers,
                              pin_memory=False,
                              sampler=sampler,
                              drop_last=True,
                              collate_fn=collate_fn)

    dist.barrier()
    sampler.set_epoch(0)
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break
Example #8
    def prepare_dist_data_loader(self, dataset, batch_size, epoch=0):
        # prepare distributed data loader
        data_sampler = DistributedSampler(dataset)
        data_sampler.set_epoch(epoch)

        if self.custom_collate_fn is None:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler)
        else:
            dataloader = DataLoader(dataset,
                                    batch_size=batch_size,
                                    sampler=data_sampler,
                                    collate_fn=self.custom_collate_fn)
        return dataloader
Example #9
def get_data_loader(dataset, collate_func, batch_size, distributed, epoch):
    if distributed:
        from torch.utils.data.distributed import DistributedSampler
        data_sampler = DistributedSampler(dataset)
        data_sampler.set_epoch(epoch)
        loader = torch.utils.data.DataLoader(dataset,
                                             collate_fn=collate_func,
                                             batch_size=batch_size,
                                             sampler=data_sampler)
    else:
        loader = torch.utils.data.DataLoader(dataset,
                                             collate_fn=collate_func,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4)
    return loader
Example #10
def training(rank):
    if rank == 0:
        global logger
        logger = get_logger(__name__, "train.log")
    setup(rank)

    dataset = ToyDataset()
    sampler = DistributedSampler(dataset,
                                 num_replicas=dist.get_world_size(),
                                 rank=rank)
    data_loader = DataLoader(dataset,
                             batch_size=2,
                             shuffle=False,
                             sampler=sampler)

    model = ToyModel()
    optimizer = optim.SGD(model.parameters(),
                          lr=1e-2,
                          momentum=0.9,
                          weight_decay=1e-4)
    model = model.cuda()
    model = DistributedDataParallel(model,
                                    device_ids=[rank],
                                    output_device=rank)

    loss_fn = nn.MSELoss()
    epochs = 200

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        for ite, (inputs, labels) in enumerate(data_loader):
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if dist.get_rank() == 0:
            logger.info(
                "pid:%s rank:%s epoch:%s loss:%s batch_size:%s" %
                (os.getpid(), rank, epoch, loss.item(), inputs.shape[0]))

        if epoch == epochs - 1 and dist.get_rank() == 0:
            torch.save(model.state_dict(), "toy.pth")
Example #11
def main_worker(gpu, ngpus_per_node, args, num_jobs):
    args.gpu = gpu

    args.rank = args.rank * ngpus_per_node + gpu
    print(f"  Use GPU: local[{args.gpu}] | global[{args.rank}]")
    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    world_size = dist.get_world_size()
    local_writers = [open(f"{args.output_dir}/decode.{i+1}.ark", "wb")
                     for i in range(args.rank, num_jobs, world_size)]

    inferset = InferDataset(args.input_scp)
    res = len(inferset) % args.world_size
    if res > 0:
        inferset.dataset = inferset.dataset[:-res]

    dist_sampler = DistributedSampler(inferset)
    dist_sampler.set_epoch(1)

    testloader = DataLoader(
        inferset, batch_size=1, shuffle=(dist_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=dist_sampler)

    with open(args.config, 'r') as fi:
        configures = json.load(fi)

    model = build_model(args, configures, train=False)

    torch.cuda.set_device(args.gpu)
    model.cuda(args.gpu)
    model.load_state_dict(torch.load(
        args.resume, map_location=f"cuda:{args.gpu}"))
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.gpu])
    model.eval()

    if args.rank == 0:
        print("> Model built.")
        print("  Model size:{:.2f}M".format(
            utils.count_parameters(model)/1e6))

    cal_logit(model, testloader, args.gpu, local_writers)
Example #12
    def _get_sampler(self, epoch: int):
        """
        Return a :class:`torch.utils.data.sampler.Sampler` to sample the data.

        This is used to distribute the data across the replicas. If shuffling
        is enabled, every epoch will have a different shuffle.

        Args:
            epoch: The epoch being fetched.

        Returns:
            A sampler which tells the data loader which sample to load next.
        """
        world_size = get_world_size()
        rank = get_rank()
        sampler = DistributedSampler(self,
                                     num_replicas=world_size,
                                     rank=rank,
                                     shuffle=self.shuffle)
        sampler.set_epoch(epoch)
        return sampler
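
The docstring above describes what the returned sampler does; a hedged, self-contained sketch of how such a _get_sampler method is typically consumed (ToyShardedDataset and the single-replica fallback are illustrative assumptions, not from the source):

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class ToyShardedDataset(Dataset):
    def __init__(self, n=32, shuffle=True):
        self.data = torch.arange(n)
        self.shuffle = shuffle

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def _get_sampler(self, epoch):
        # Single replica assumed here; a real setup would query get_world_size()/get_rank().
        sampler = DistributedSampler(self, num_replicas=1, rank=0, shuffle=self.shuffle)
        sampler.set_epoch(epoch)
        return sampler


dataset = ToyShardedDataset()
for epoch in range(2):
    # Requesting a fresh sampler per epoch gives each epoch its own shuffle.
    loader = DataLoader(dataset, batch_size=8, sampler=dataset._get_sampler(epoch))
    for batch in loader:
        pass  # training step would go here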
Example #13
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)
    print('random seed: ', args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)   
    
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()


    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
        
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)

    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    if args.SinglePath:
        arch_optimizer = torch.optim.Adam(
            [param for name, param in model.named_parameters() if name == 'log_alpha'],
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay
        )

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(args.base_lr)  + '_seed_' + str(args.seed) + '_pretrain_' + str(args.pretrain_epoch)

    if args.early_fix_arch:
        remark += '_early_fix_arch'  

    if args.flops_loss:
        remark += '_flops_loss_' + str(args.flops_loss_coef)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    model_dir = path
    start_epoch = 0
    
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True
    cudnn.enabled = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, 85):
        train_sampler.set_epoch(epoch)
        
        if args.early_fix_arch:
            if len(model.fix_arch_index.keys()) > 0:
                for key, value_lst in model.fix_arch_index.items():
                    model.log_alpha.data[key, :] = value_lst[1]
            sort_log_alpha = torch.topk(F.softmax(model.log_alpha.data, dim=-1), 2)
            argmax_index = (sort_log_alpha[0][:,0] - sort_log_alpha[0][:,1] >= 0.3)
            for id in range(argmax_index.size(0)):
                if argmax_index[id] == 1 and id not in model.fix_arch_index.keys():
                    model.fix_arch_index[id] = [sort_log_alpha[1][id,0].item(), model.log_alpha.detach().clone()[id, :]]
            
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
            logging.info(model.log_alpha)         
            logging.info(F.softmax(model.log_alpha, dim=-1))         
            logging.info('flops %fM', model.cal_flops())  

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)
        if args.gen_max_child:
            args.gen_max_child_flag = True
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)        
            args.gen_max_child_flag = False

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example #14
class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
    graph and the corresponding label tensor (if provided) for each minibatch.

    Parameters
    ----------
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    use_ddp : boolean, optional
        If True, tells the DataLoader to split the training set for each
        participating process appropriately using
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
    ddp_seed : int, optional
        The seed for shuffling the dataset in
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Only effective when :attr:`use_ddp` is True.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)

    **Using with Distributed Data Parallel**

    If you are using PyTorch's distributed training (e.g. when using
    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
    turning on the :attr:`use_ddp` option:

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for epoch in range(start_epoch, n_epochs):
    ...     dataloader.set_epoch(epoch)
    ...     for batched_graph, labels in dataloader:
    ...         train_on(batched_graph, labels)
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self,
                 dataset,
                 collate_fn=None,
                 use_ddp=False,
                 ddp_seed=0,
                 **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.use_ddp = use_ddp
        if use_ddp:
            self.dist_sampler = DistributedSampler(
                dataset,
                shuffle=dataloader_kwargs['shuffle'],
                drop_last=dataloader_kwargs['drop_last'],
                seed=ddp_seed)
            dataloader_kwargs['shuffle'] = False
            dataloader_kwargs['drop_last'] = False
            dataloader_kwargs['sampler'] = self.dist_sampler

        self.dataloader = DataLoader(dataset=dataset,
                                     collate_fn=self.collate,
                                     **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

    def set_epoch(self, epoch):
        """Sets the epoch number for the underlying sampler which ensures all replicas
        to use a different ordering for each epoch.

        Only available when :attr:`use_ddp` is True.

        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

        Parameters
        ----------
        epoch : int
            The epoch number.
        """
        if self.use_ddp:
            self.dist_sampler.set_epoch(epoch)
        else:
            raise DGLError('set_epoch is only available when use_ddp is True.')
Example #15
def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=4, help='number of cpu threads to use')
    arg('--gpu',
        type=str,
        default='0',
        help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='softmax_')
    arg('--data-dir', type=str, default="/home/selim/datasets/xview/train")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=1)
    arg("--local_rank", default=0, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--predictions", default="../oof_preds", type=str)
    arg("--test_every", type=int, default=1)

    args = parser.parse_args()

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = models.__dict__[conf['network']](seg_classes=conf['num_classes'],
                                             backbone_arch=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    damage_loss_function = losses.__dict__[conf["damage_loss"]["type"]](
        **conf["damage_loss"]["params"]).cuda()
    loss_functions = {"damage_loss": damage_loss_function}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)

    dice_best = 0
    xview_best = 0
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = XviewSingleDataset(
        mode="train",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_train_transforms(conf['input']),
        multiplier=conf["data_multiplier"],
        normalize=conf["input"].get("normalize", None),
        equibatch=False)
    data_val = XviewSingleDataset(
        mode="val",
        fold=args.fold,
        data_path=args.data_dir,
        folds_csv=args.folds_csv,
        transforms=create_val_transforms(conf['input']),
        normalize=conf["input"].get("normalize", None))
    train_sampler = None
    if args.distributed:
        train_sampler = DistributedSampler(data_train)

    train_data_loader = DataLoader(data_train,
                                   batch_size=batch_size,
                                   num_workers=args.workers,
                                   shuffle=False,
                                   sampler=train_sampler,
                                   pin_memory=False,
                                   drop_last=True)
    val_batch_size = 1
    val_data_loader = DataLoader(data_val,
                                 batch_size=val_batch_size,
                                 num_workers=args.workers,
                                 shuffle=False,
                                 pin_memory=False)

    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + args.prefix +
                                   conf['encoder'])
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            if conf['optimizer'].get('zero_decoder', False):
                for key in state_dict.copy().keys():
                    if key.startswith("module.final"):
                        del state_dict[key]
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    dice_best = checkpoint.get('dice_best', 0)
                    xview_best = checkpoint.get('xview_best', 0)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(args.prefix, conf['network'],
                                        conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    for epoch in range(start_epoch, conf['optimizer']['schedule']['epochs']):
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            if hasattr(model.module, 'encoder_stages1'):
                model.module.encoder_stages1.eval()
                model.module.encoder_stages2.eval()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = False
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = False
            else:
                model.module.encoder_stages.eval()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = False
        else:
            if hasattr(model.module, 'encoder_stages1'):
                print("Unfreezing encoder!!!")
                model.module.encoder_stages1.train()
                model.module.encoder_stages2.train()
                for p in model.module.encoder_stages1.parameters():
                    p.requires_grad = True
                for p in model.module.encoder_stages2.parameters():
                    p.requires_grad = True
            else:
                model.module.encoder_stages.train()
                for p in model.module.encoder_stages.parameters():
                    p.requires_grad = True

        if train_sampler:
            train_sampler.set_epoch(current_epoch)
        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler,
                    train_data_loader, summary_writer, conf, args.local_rank)

        model = model.eval()
        if args.local_rank == 0:
            torch.save(
                {
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'dice_best': dice_best,
                    'xview_best': xview_best,
                }, args.output_dir + '/' + snapshot_name + "_last")
            if epoch % args.test_every == 0:
                preds_dir = os.path.join(args.predictions, snapshot_name)
                dice_best, xview_best = evaluate_val(
                    args,
                    val_data_loader,
                    xview_best,
                    dice_best,
                    model,
                    snapshot_name=snapshot_name,
                    current_epoch=current_epoch,
                    optimizer=optimizer,
                    summary_writer=summary_writer,
                    predictions_dir=preds_dir)
        current_epoch += 1
Example #16
def run():

    #Model
    net = ResNet18()
    net = net.cuda()

    #Data
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.ImageFolder(root=args.dir_data_train,
                                                     transform=transform_train)
    train_sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=64,
                                               shuffle=(train_sampler is None),
                                               num_workers=2,
                                               pin_memory=True,
                                               sampler=train_sampler)
    test_dataset = torchvision.datasets.ImageFolder(root=args.dir_data_val,
                                                    transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=128,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    # Optimizer and scheduler of Training
    optimizer = optim.SGD(net.parameters(), lr=args.lr)

    #Training
    top1 = AverageMeter()
    top5 = AverageMeter()
    logs = []
    print("Training Start")
    for epoch in range(args.epochs):

        print("Training for epoch {}".format(epoch))
        net.train()
        train_sampler.set_epoch(epoch)
        for i, data in enumerate(train_loader, 0):
            batch_start_time = time.time()
            x, label = data
            x, label = Variable(x).cuda(), Variable(label).cuda()

            # optimizer.zero_grad()
            output = net(x)
            loss = criterion(output, label)

            prec1, prec5 = accuracy(output.data, label, topk=(1, 5))
            top1.update(prec1[0], x.size(0))
            top5.update(prec5[0], x.size(0))

            loss.backward()
            paralist = gradient_execute(net)
            optimizer.step()
            for para1, para2 in zip(paralist, net.parameters()):
                para2.grad.data = para1

            log_obj = {
                'timestamp': datetime.now(),
                'iteration': i,
                'training_loss': loss.data.item(),
                'training_accuracy1': top1.avg.item() / 100.0,
                'training_accuracy5': top5.avg.item() / 100.0,
                'total_param': Total_param_num,
                'sparse_param': Sparse_param_num,
                'mini_batch_time': (time.time() - batch_start_time)
            }
            if i % 20 == 0:
                print("Timestamp: {timestamp} | "
                      "Iteration: {iteration:6} | "
                      "Loss: {training_loss:6.4f} | "
                      "Accuracy1: {training_accuracy1:6.4f} | "
                      "Accuracy5: {training_accuracy5:6.4f} | "
                      "Total_param: {total_param:6} | "
                      "Sparse_param: {sparse_param:6} | "
                      "Mini_Batch_Time: {mini_batch_time:6.4f} | ".format(
                          **log_obj))

            logs.append(log_obj)

        if True:
            logs[-1]['test_loss'], logs[-1]['test_accuracy1'], logs[-1][
                'test_accuracy5'] = evaluate(net, test_loader)
            print("Timestamp: {timestamp} | "
                  "Iteration: {iteration:6} | "
                  "Loss: {training_loss:6.4f} | "
                  "Accuracy1: {training_accuracy1:6.4f} | "
                  "Accuracy5: {training_accuracy5:6.4f} | "
                  "Total_param: {total_param:6} | "
                  "Sparse_param: {sparse_param:6} | "
                  "Mini_Batch_Time: {mini_batch_time:6.4f} | "
                  "Test Loss: {test_loss:6.4f} | "
                  "Test Accuracy1: {test_accuracy1:6.4f} | "
                  "Test_Accuracy5: {test_accuracy5:6.4f}".format(**logs[-1]))

    df = pd.DataFrame(logs)
    df.to_csv('./log/{}_Node{}_{}.csv'.format(
        args.file_name, args.dist_rank,
        datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
              index_label='index')
    print("Finished Training")
Example #17
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, loss_empthasis, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard, logdirname, datedlogdir, warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    global WaveGlow
    global WaveGlowLoss
    
    ax = True # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config["yoyo"]: # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss
    
    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======
    STFTs = [STFT.TacotronSTFT(filter_length=window,
                                 hop_length=data_config['hop_length'],
                                 win_length=window,
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=160,
                                 mel_fmin=0, mel_fmax=16000) for window in data_config['validation_windows']]
    
    loader_STFT = STFT.TacotronSTFT(filter_length=data_config['filter_length'],
                                 hop_length=data_config['hop_length'],
                                 win_length=data_config['win_length'],
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=data_config['n_mel_channels'] if 'n_mel_channels' in data_config.keys() else 160,
                                 mel_fmin=data_config['mel_fmin'], mel_fmax=data_config['mel_fmax'])
    
    #optimizer = "Adam"
    optimizer = optimizer.lower()
    optimizer_fused = bool( 0 ) # use Apex fused optimizer, should be identical to normal but slightly faster and only works on RTX cards
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=200)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.
    
    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True, threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else: scheduler=False
    
    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                      optimizer, scheduler, fp16_run, warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0
    
    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)
    
    moving_average = int(min(len(train_loader), 100)) # average loss over entire Epoch
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()
    
    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20
    
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9
    
    epoch_offset = max(0, int(iteration / len(train_loader)))
    
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))
    
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")
    
    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset, total=epochs, smoothing=0.01, desc="Epoch", position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)
                
                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0, total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check the actual learning rate every 50 iters (because I sometimes see the learning_rate variable go out-of-sync with the real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print("scheduler._last_lr =", scheduler._last_lr, "scheduler.best =", scheduler.best, "  |", end='')
                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    outputs = model(mel, audio, speaker_ids)
                    
                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                    else:
                        reduced_loss = loss.item()
                    
                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    
                    if (reduced_loss > LossExplosionThreshold) or (math.isnan(reduced_loss)):
                        model.zero_grad()
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")
                    
                    if use_grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                        if type(grad_norm) == torch.Tensor:
                            grad_norm = grad_norm.item()
                        is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                    else: is_overflow = False; grad_norm=0.00001
                    
                    optimizer.step()
                    if not is_overflow and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            logger.add_scalar('training_loss', reduced_loss, iteration)
                            logger.add_scalar('training_loss_samples', reduced_loss, iteration*batch_size)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)
                        
                        average_loss = rolling_sum.process(reduced_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), (time.time() - start_time_dekaiter)/10, ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), loss_scale))
                        start_time_iter = time.time()
                    
                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            except KeyboardInterrupt: # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss  # Only overwrite the best model when the loss improves by at least best_model_margin.
                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                        checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)
                    
                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(model, loader_STFT, STFTs, logger, iteration, data_config['validation_files'], speaker_lookup, sigma, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    except KeyboardInterrupt: # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item()  # Only overwrite the best validation model when the MSE improves.
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                        learning_rate = optimizer.param_groups[0]['lr'] #check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False # exit the While loop
        
        except LossExplosion as ex: # Print the exception and resume from the best checkpoint (restarting this way takes under 4 seconds).
            print(ex) # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_model must exist for automatic restarts"
            
            # clearing VRAM for load checkpoint
            audio = mel = speaker_ids = loss = None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model, optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass # and continue training.
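The loop above references a LossExplosion exception and a rolling_sum helper that are not part of the snippet; a minimal sketch of what they might look like (the class names and window size are assumptions):

import collections


class LossExplosion(Exception):
    """Raised when the training loss exceeds LossExplosionThreshold or becomes NaN."""
    pass


class RollingAverage:
    """Keep the last `window` loss values and return their running mean."""
    def __init__(self, window=100):
        self.values = collections.deque(maxlen=window)

    def process(self, value):
        self.values.append(value)
        return sum(self.values) / len(self.values)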
Exemple #18
0
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)  # safe_load avoids executing arbitrary YAML tags

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size

    # create model
    print("=> creating model '{}'".format(args.model))
    if 'resnetv1sn' in args.model:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn,
            last_gamma=args.last_gamma)
    else:
        model = models.__dict__[args.model](
            using_moving_average=args.using_moving_average,
            using_bn=args.using_bn)

    model.cuda()
    broadcast_params(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # auto resume from a checkpoint
    model_dir = args.model_dir
    start_epoch = 0
    if args.rank == 0 and not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir,
                                             model,
                                             optimizer=optimizer)
    if args.rank == 0:
        writer = SummaryWriter(model_dir)
    else:
        writer = None

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root, args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root, args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size // args.world_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size // args.world_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, lr_scheduler, epoch,
              writer)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                model_dir, {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
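The snippet above calls init_dist and broadcast_params, which are not shown; a minimal sketch of what they might look like under an env://-style launcher (the environment variables and defaults here are assumptions):

import os

import torch
import torch.distributed as dist


def init_dist(backend='nccl', port=29500):
    # Minimal sketch (assumption): rank and world size come from the launcher's
    # environment variables; the project's real init_dist may differ.
    rank = int(os.environ.get('RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', str(port))
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend=backend, init_method='env://',
                            world_size=world_size, rank=rank)
    return rank, world_size


def broadcast_params(model, root_rank=0):
    # Make every process start from identical weights before training.
    for p in model.state_dict().values():
        dist.broadcast(p, root_rank)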
Exemple #19
0
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes.
            args.rank = args.rank * ngpus_per_node + gpu

        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank
        )

    # load model here
    # model = maskrcnn001(num_classes=2)

    model = arch(num_classes=2)

    if args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per 
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set.
            model = DistributedDataParallel(model) 
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        # model = torch.nn.DataParallel(model).cuda()
        model = model.cuda()

    if args.distributed:
        # model = DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # lr_scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)
    lr_scheduler = MultiStepLR(optimizer, milestones=[20000, 40000], gamma=0.1)

    # ================================
    # resume RESUME CHECKPOINT
    if IS_SM:  # load latest checkpoints 
        checkpoint_list = os.listdir(checkpoint_dir)

        logger.info("=> Checking checkpoints dir.. {}".format(checkpoint_dir))
        logger.info(checkpoint_list)

        latest_path_parent = ""
        latest_path = ""
        latest_iter_num = -1

        for checkpoint_path in natsorted(glob.glob(os.path.join(checkpoint_dir, "*.pth"))):
            checkpoint_name = os.path.basename(checkpoint_path)
            logger.info("Found checkpoint {}".format(checkpoint_name))
            iter_num = int(os.path.splitext(checkpoint_name)[0].split("_")[-1])

            if iter_num > latest_iter_num:
                latest_path_parent = latest_path
                latest_path = checkpoint_path
                latest_iter_num = iter_num 

        logger.info("> latest checkpoint is {}".format(latest_path))

        if latest_path_parent:
            logger.info("=> loading checkpoint {}".format(latest_path_parent))
            checkpoint = torch.load(latest_path_parent, map_location="cpu")
            model_without_ddp.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

            args.start_epoch = checkpoint["epoch"]
            args.iter_num = checkpoint["iter_num"]
            logger.info("==> args.iter_num is {}".format(args.iter_num))

    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return
    
    logger.info("==================================")
    logger.info("Create dataset with root_dir={}".format(args.train_data_path))
    assert os.path.exists(args.train_data_path), "root_dir does not exists!"
    train_set = TableBank(root_dir=args.train_data_path)

    if args.distributed:
        train_sampler = DistributedSampler(train_set)
    else:
        train_sampler = RandomSampler(train_set)

    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(
            train_set,
            k=args.aspect_ratio_group_factor
        )
        train_batch_sampler = GroupedBatchSampler(
            train_sampler,
            group_ids,
            args.batch_size
        )
    else:
        train_batch_sampler = BatchSampler(
            train_sampler,
            args.batch_size,
            drop_last=True
        )

    logger.info("Create data_loader.. with batch_size = {}".format(args.batch_size))
    train_loader = DataLoader(
        train_set,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=utils.collate_fn,
        pin_memory=True
    )

    logger.info("Start training.. ")

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        train_one_epoch(
            model=model,
            arch=arch,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            data_loader=train_loader,
            device=args.gpu,
            epoch=epoch,
            print_freq=args.print_freq,
            ngpus_per_node=ngpus_per_node,
            model_without_ddp=model_without_ddp,
            args=args
        )
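main_worker above is typically started once per GPU; a minimal launcher sketch under that assumption (parse_args is a hypothetical stand-in for the project's argument parser):

import torch
import torch.multiprocessing as mp


def main():
    args = parse_args()  # hypothetical argument parser from the surrounding project
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # The global world size spans all nodes; one process is spawned per GPU.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)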
Exemple #20
0
def model_trainer(args):
    # Load MNIST
    data_root = './'
    train_set = MNIST(root=data_root, download=True, train=True, transform=ToTensor())
    train_sampler = DistributedSampler(train_set)
    same_seeds(args.seed_num)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=(train_sampler is None), pin_memory=True, sampler=train_sampler)
    valid_set = MNIST(root=data_root, download=True, train=False, transform=ToTensor())
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print(f'Now Training: {args.exp_name}')
    # Load model
    same_seeds(args.seed_num)
    model = Toy_Net()
    model = model.to(args.local_rank)

    # Model parameters
    os.makedirs(f'./experiment_model/', exist_ok=True)
    latest_model_path = f'./experiment_model/{args.exp_name}'
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, nesterov=True)
    lookahead = Lookahead(optimizer=optimizer, k=10, alpha=0.5)
    loss_function = nn.CrossEntropyLoss()
    if args.local_rank == 0:
        best_valid_acc = 0

    # Callbacks
    warm_up = lambda epoch: epoch / args.warmup_epochs if epoch <= args.warmup_epochs else 1
    scheduler_wu = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=warm_up)
    scheduler_re = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    early_stopping = EarlyStopping(patience=50, verbose=True)
            
    # Apex
    #amp.register_float_function(torch, 'sigmoid')   # register for uncommonly function
    model, apex_optimizer = amp.initialize(model, optimizers=lookahead, opt_level="O1")

    # Build training model
    parallel_model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    # Train model
    if args.local_rank == 0:
        tb = SummaryWriter(f'./tensorboard_runs/{args.exp_name}')
    #apex_optimizer.zero_grad()
    #apex_optimizer.step()
    for epoch in range(args.epochs):
        epoch_start_time = time.time()
        train_loss, train_acc = 0., 0.
        valid_loss, valid_acc = 0., 0.
        train_num, valid_num = 0, 0
        train_sampler.set_epoch(epoch)

        # Train
        parallel_model.train()
        # Warm up
        #if epoch < args.warmup_epochs:
        #    scheduler_wu.step()
        for image, target in tqdm(train_loader, total=len(train_loader)):
            apex_optimizer.zero_grad()
            image = image.to(args.local_rank)
            target = target.to(args.local_rank, dtype=torch.long)
            outputs = parallel_model(image)
            predict = torch.argmax(outputs, dim=1)
            batch_loss = loss_function(outputs, target)
            batch_loss /= len(outputs)
            # Apex
            with amp.scale_loss(batch_loss, apex_optimizer) as scaled_loss:
                scaled_loss.backward()
            apex_optimizer.step()

            # Calculate loss & acc
            train_loss += batch_loss.item() * len(image)
            train_acc += (predict == target).sum().item()
            train_num += len(image)

        train_loss = train_loss / train_num
        train_acc = train_acc / train_num
        curr_lr = apex_optimizer.param_groups[0]['lr']
        if args.local_rank == 0:
            tb.add_scalar('LR', curr_lr, epoch)
            tb.add_scalar('Loss/train', train_loss, epoch)
            tb.add_scalar('Acc/train', train_acc, epoch)

        # Valid
        parallel_model.eval()
        with torch.no_grad():
            for image, target in tqdm(valid_loader, total=len(valid_loader)):
                image = image.to(args.local_rank)
                target = target.to(args.local_rank, dtype=torch.long)
                outputs = parallel_model(image)
                predict = torch.argmax(outputs, dim=1)
                batch_loss = loss_function(outputs, target)
                batch_loss /= len(outputs)
                    
                # Calculate loss & acc
                valid_loss += batch_loss.item() * len(image)
                valid_acc += (predict == target).sum().item()
                valid_num += len(image)

        valid_loss = valid_loss / valid_num
        valid_acc = valid_acc / valid_num
        if args.local_rank == 0:
            tb.add_scalar('Loss/valid', valid_loss, epoch)
            tb.add_scalar('Acc/valid', valid_acc, epoch)
            
        # Print result
        print(f'epoch: {epoch:03d}/{args.epochs}, time: {time.time()-epoch_start_time:.2f}s, learning_rate: {curr_lr}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, valid_loss: {valid_loss:.4f}, valid_acc: {valid_acc:.4f}')

        # Learning_rate callbacks
        if epoch <= args.warmup_epochs:
            scheduler_wu.step()
        scheduler_re.step(valid_loss)
        early_stopping(valid_loss)
        if early_stopping.early_stop:
            break

        # Save_checkpoint
        if args.local_rank == 0:
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                torch.save(parallel_model.module.state_dict(), f'{latest_model_path}.pt')

    if args.local_rank == 0:
        tb.close()
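model_trainer above assumes the process group is already initialized and that args.local_rank is set by the launcher; a minimal setup sketch under that assumption (get_args is hypothetical):

import torch
import torch.distributed as dist


def setup_distributed(args):
    # One process per GPU, e.g. launched with
    #   python -m torch.distributed.launch --nproc_per_node=N train.py
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')


if __name__ == '__main__':
    args = get_args()  # hypothetical argument parser providing local_rank, batch_size, ...
    setup_distributed(args)
    model_trainer(args)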
Exemple #21
0
def train(model, local_rank):
    if local_rank == 0:
        writer = SummaryWriter('train')
        writer_val = SummaryWriter('validate')
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset,
                            batch_size=args.batch_size,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True,
                            sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val,
                          batch_size=16,
                          pin_memory=True,
                          num_workers=8)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu = data.to(device, non_blocking=True) / 255.
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            learning_rate = get_learning_rate(step)
            pred, info = model.update(imgs, gt, learning_rate, training=True)
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 200 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss/l1', info['loss_l1'], step)
                writer.add_scalar('loss/tea', info['loss_tea'], step)
                writer.add_scalar('loss/distill', info['loss_distill'], step)
            if step % 1000 == 1 and local_rank == 0:
                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() *
                      255).astype('uint8')
                mask = (torch.cat((info['mask'], info['mask_tea']), 3).permute(
                    0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() *
                        255).astype('uint8')
                merged_img = (info['merged_tea'].permute(
                    0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
                flow1 = info['flow_tea'].permute(0, 2, 3,
                                                 1).detach().cpu().numpy()
                # Use a separate index so the outer batch counter `i` is not overwritten.
                for j in range(5):
                    imgs = np.concatenate((merged_img[j], pred[j], gt[j]),
                                          1)[:, :, ::-1]
                    writer.add_image(str(j) + '/img',
                                     imgs,
                                     step,
                                     dataformats='HWC')
                    writer.add_image(
                        str(j) + '/flow',
                        np.concatenate(
                            (flow2rgb(flow0[j]), flow2rgb(flow1[j])), 1),
                        step,
                        dataformats='HWC')
                    writer.add_image(str(j) + '/mask',
                                     mask[j],
                                     step,
                                     dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print(
                    'epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(
                        epoch, i, args.step_per_epoch, data_time_interval,
                        train_time_interval, info['loss_l1']))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)
        dist.barrier()
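flow2rgb used above is not part of the snippet; a minimal sketch that turns a flow field into an RGB image for TensorBoard (the colour mapping is an assumption, the project may use a different flow wheel):

import numpy as np


def flow2rgb(flow_map):
    # Minimal sketch (assumption): encode the first two flow channels as the
    # red/green channels of an RGB image normalized to [0, 1].
    h, w, _ = flow_map.shape
    rgb = np.zeros((h, w, 3), dtype=np.float32)
    norm = np.abs(flow_map).max() + 1e-8
    rgb[..., 0] = flow_map[..., 0] / norm * 0.5 + 0.5
    rgb[..., 1] = flow_map[..., 1] / norm * 0.5 + 0.5
    return rgb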
Exemple #22
0
def main():
    model: nn.Module = resnet34(num_classes=10).cuda()

    # Set variables here. These are just for demonstration so no need for argparse.
    batch_size = 1024  # Per-process mini-batch size; DistributedSampler splits the dataset across processes.
    num_workers_per_process = 2  # Workers launched by each process started by horovodrun command.
    lr = 0.1
    momentum = 0.9
    weight_decay = 1E-4
    root_rank = 0
    num_epochs = 10

    train_transform = Compose(
        [RandomHorizontalFlip(),
         RandomCrop(size=32, padding=4),
         ToTensor()])
    train_dataset = CIFAR10(root='./data',
                            train=True,
                            transform=train_transform,
                            download=True)
    test_dataset = CIFAR10(root='./data',
                           train=False,
                           transform=ToTensor(),
                           download=True)

    # Distributed samplers are necessary for accurately dividing the dataset among the processes.
    # It also controls mini-batch size effectively between processes.
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=hvd.size(),
                                       rank=hvd.rank())
    test_sampler = DistributedSampler(test_dataset,
                                      num_replicas=hvd.size(),
                                      rank=hvd.rank())

    # Create iterable to allow manual unwinding of the for loop.
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers_per_process,
                              pin_memory=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             sampler=test_sampler,
                             num_workers=num_workers_per_process,
                             pin_memory=True)

    loss_func = nn.CrossEntropyLoss()

    # Writing separate log files for each process for comparison. Verified that models are different.
    writer = SummaryWriter(log_dir=f'./logs/{hvd.local_rank()}',
                           comment='Summary writer for run.')

    # Optimizer must be distributed for the Ring-AllReduce.
    optimizer = SGD(params=model.parameters(),
                    lr=lr,
                    momentum=momentum,
                    weight_decay=weight_decay)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    def warm_up(epoch: int):  # Learning rate warm-up.
        if epoch < 5:
            return (epoch + 1) / 5
        elif epoch < 75:
            return 1
        elif epoch < 90:
            return 0.1
        else:
            return 0.01

    scheduler = LambdaLR(
        optimizer, lr_lambda=warm_up)  # Learning rate scheduling with warm-up.

    # Broadcast the model's parameters to all devices.
    hvd.broadcast_parameters(model.state_dict(), root_rank=root_rank)

    for epoch in range(num_epochs):
        print(epoch)
        torch.autograd.set_grad_enabled(True)  # Training mode (must be called, not assigned).
        train_sampler.set_epoch(
            epoch
        )  # Set epoch to sampler for proper shuffling of training set.
        for inputs, targets in train_loader:
            inputs: Tensor = inputs.cuda(non_blocking=True)
            targets: Tensor = targets.cuda(non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = loss_func(outputs, targets)
            loss.backward()
            optimizer.step()

        torch.autograd.set_grad_enabled(False)  # Evaluation mode (must be called, not assigned).
        for step, (inputs, targets) in enumerate(test_loader):
            inputs: Tensor = inputs.cuda(non_blocking=True)
            targets: Tensor = targets.cuda(non_blocking=True)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)
            writer.add_scalar(tag='val_loss',
                              scalar_value=loss.item(),
                              global_step=step)

        scheduler.step()  # Scheduler works fine on DistributedOptimizer.
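The snippet above uses hvd.size(), hvd.rank() and hvd.local_rank() but does not show the Horovod initialization; a minimal sketch of the setup it assumes:

import horovod.torch as hvd
import torch

if __name__ == '__main__':
    # Typically launched as: horovodrun -np 4 python train.py
    hvd.init()
    torch.cuda.set_device(hvd.local_rank())
    main()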
Exemple #23
0
def train(gpu, config):
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=config['num_gpus'],
                            rank=gpu)
    torch.cuda.set_device(gpu)
    """ 
        @ build the dataset for training
    """
    dataset = get_data(config)
    trainset = dataset(config, "train")
    testset = dataset(config, "test")
    sampler_train = DistributedSampler(trainset,
                                       num_replicas=config['num_gpus'],
                                       rank=gpu)
    sampler_val = DistributedSampler(testset,
                                     num_replicas=config['num_gpus'],
                                     rank=gpu)

    batch_size = config['batch_size']
    loader_train = DataLoader(dataset=trainset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=config['num_threads'],
                              pin_memory=True,
                              sampler=sampler_train,
                              drop_last=True)
    loader_val = DataLoader(dataset=testset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1,
                            pin_memory=True,
                            sampler=sampler_val,
                            drop_last=True)
    model = UNet(config["in_channels"],
                 config["out_channels"],
                 post_processing=True)
    model.cuda(gpu)
    mask_sampling = masksamplingv2()
    """  @ init parameter
    """

    save_folder = os.path.join(
        config['save_root'], 'batch_{}_lr_{}'.format(config['batch_size'],
                                                     config['lr']))
    best_epoch = 0
    lowest_loss = 0.
    resume = 0
    print('=>Save folder: {}\n'.format(save_folder))
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    optimizer = define_optim(config['optimizer'], model.parameters(),
                             float(config['lr']), 0)

    criterion_1 = define_loss(config['loss_type'])
    criterion_2 = define_loss("Multimse")
    scheduler = define_scheduler(optimizer, config)
    """
        @ justify the resume model
    """
    if config['resume'] != 'None':
        checkpoint = torch.load(config['resume'],
                                map_location=torch.device('cpu'))
        resume = checkpoint['epoch']
        lowest_loss = checkpoint['loss']
        best_epoch = checkpoint['best_epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O0',
                                          verbosity=0)
        amp.load_state_dict(checkpoint['amp'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            resume, checkpoint['epoch']))
        del checkpoint
    log_file = 'log_train_start_{}.txt'.format(resume)
    """
        @ convert model to multi-gpus modes for training
    """
    model = apex.parallel.convert_syncbn_model(model)
    if config['resume'] == 'None':
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O0',
                                          verbosity=0)
    model = DDP(model)
    if gpu == 0:
        sys.stdout = Logger(os.path.join(save_folder, log_file))
    print("Number of parameters in model is {:.3f}M".format(
        sum(tensor.numel() for tensor in model.parameters()) / 1e6))
    """
        @ start to train
    """
    for epoch in range(resume + 1, config['epoches'] + 1):
        print('=> Start Epoch {}\n'.format(epoch))
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        print('learning rate is set to {}.\n'.format(
            optimizer.param_groups[0]['lr']))
        model.train()
        sampler_train.set_epoch(epoch)
        batch_time = AverageMeter()
        losses = AverageMeter()
        metric_train = Metrics()
        rmse_train = AverageMeter()
        mae_train = AverageMeter()
        time_snap = time.time()
        for i, inputs in tqdm(enumerate(loader_train)):
            gt, noise = inputs['gt'].cuda(gpu), inputs['noise'].cuda(gpu)
            optimizer.zero_grad()
            """ update the train inputs
            """
            # patten = np.random.randint(0, 4, 1)
            patten = torch.randint(0, 8, (1, ))
            redinput, blueinput = mask_sampling(noise, patten)

            # redinput, blueinput = generator(noise, mask1, mask2)
            output = model(redinput)
            loss = criterion_1(output, blueinput)
            fulloutput = model(noise)
            redoutput, blueoutput = mask_sampling(fulloutput, patten)
            # redoutput, blueoutput = generator(fulloutput, mask1, mask2)

            loss2 = criterion_2(output, blueinput, redoutput, blueoutput)
            losssum = config["gamma"] * loss2 + loss
            with amp.scale_loss(losssum, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            "@ map-reduce tensor"
            rt = reduce_tensor(losssum.data)
            torch.cuda.synchronize()
            losses.update(rt.item(), loader_train.batch_size)
            metric_train.calculate(fulloutput.detach(), gt)
            rmse_train.update(metric_train.get_metric('mse'), metric_train.num)
            mae_train.update(metric_train.get_metric('mae'), metric_train.num)
            batch_time.update(time.time() - time_snap)
            time_snap = time.time()
            if (i + 1) % config['print_freq'] == 0:
                if gpu == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.6f} ({loss.avg:.6f})\t'
                          'Metric {rmse_train.val:.6f} ({rmse_train.avg:.6f})'.
                          format(epoch,
                                 i + 1,
                                 len(loader_train),
                                 batch_time=batch_time,
                                 loss=losses,
                                 rmse_train=rmse_train))
            if (i + 1) % config['save_freq'] == 0:
                print('=> Start validation on the selection subset')
                rmse, mae = val(model, loader_val, epoch, gpu)
                model.train()
                if gpu == 0:
                    print("===> Average RMSE score on selection set is {:.6f}".
                          format(rmse))
                    print("===> Average MAE score on selection set is {:.6f}".
                          format(mae))
                    print(
                        "===> Last best score was RMSE of {:.6f} in epoch {}".
                        format(lowest_loss, best_epoch))

                    if rmse > lowest_loss:
                        lowest_loss = rmse
                        best_epoch = epoch
                        states = {
                            'epoch': epoch,
                            'best_epoch': best_epoch,
                            'loss': lowest_loss,
                            'state_dict': model.module.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'amp': amp.state_dict()
                        }

                        save_checkpoints(states, save_folder, epoch, gpu, True)
        # save checkpoints
        print('=> Start validation on the selection set')
        rmse, mae = val(model, loader_val, epoch, gpu)
        model.train()
        if gpu == 0:
            print("===> Average RMSE score on selection set is {:.6f}".format(
                rmse))
            print("===> Average MAE score on selection set is {:.6f}".format(
                mae))
            print("===> Last best score was RMSE of {:.6f} in epoch {}".format(
                lowest_loss, best_epoch))
            if rmse > lowest_loss:
                best_epoch = epoch
                lowest_loss = rmse
                states = {
                    'epoch': epoch,
                    'best_epoch': best_epoch,
                    'loss': lowest_loss,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'amp': amp.state_dict()
                }
                save_checkpoints(states, save_folder, epoch, gpu, True)

        if config['lr_policy'] == 'plateau':
            scheduler.step(rmse)
        else:
            scheduler.step()
        # if (epoch) % 10 == 0:
        #     config["gamma"] += 0.5
        print('=>> Model training finished!')
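reduce_tensor used above is not shown; a common implementation averages the value over all ranks, for instance:

import torch.distributed as dist


def reduce_tensor(tensor):
    # Sum the tensor over all processes and divide by the world size so that
    # every rank logs the same averaged loss.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt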
Exemple #24
0
def main():
    if args.config_path:
        with open(args.config_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        config = None

    ckpts_save_dir = args.ckpt_save_dir
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    if 'test' in args:
        test_model = args.test
    print('data folder: ', args.data_dir)

    torch.backends.cudnn.benchmark = True

    # WORLD_SIZE Generated by torch.distributed.launch.py
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    is_distributed = num_gpus > 1
    if is_distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://",
        )

    model = get_model(config)
    model_loss = ModelLossWraper(model,
                                 config.TRAIN.CLASS_WEIGHTS,
                                 config.MODEL.IS_DISASTER_PRED,
                                 config.MODEL.IS_SPLIT_LOSS,
                                 )

    if is_distributed:
        model_loss = nn.SyncBatchNorm.convert_sync_batchnorm(model_loss)
        model_loss = nn.parallel.DistributedDataParallel(
            model_loss, device_ids=[args.local_rank], output_device=args.local_rank
        )

    trainset = XView2Dataset(args.data_dir, rgb_bgr='rgb',
                             preprocessing={'flip': True,
                                            'scale': config.TRAIN.MULTI_SCALE,
                                            'crop': config.TRAIN.CROP_SIZE,
                                            })

    if is_distributed:
        train_sampler = DistributedSampler(trainset)
    else:
        train_sampler = None

    trainset_loader = torch.utils.data.DataLoader(trainset, batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                                                  shuffle=train_sampler is None, pin_memory=True, drop_last=True,
                                                  sampler=train_sampler, num_workers=num_gpus)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD([{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': lr_init}],
                                lr=lr_init,
                                momentum=0.9,
                                weight_decay=0.,
                                nesterov=False,
                                )

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    num_iters = max_epoch * len(trainset_loader)
    for epoch in range(start_epoch, max_epoch):
        if is_distributed:
            train_sampler.set_epoch(epoch)
        cur_iters = epoch * len(trainset_loader)

        for i, samples in enumerate(trainset_loader):
            lr = adjust_learning_rate(optimizer, lr_init, num_iters, i + cur_iters)

            inputs_pre = samples['pre_img']
            inputs_post = samples['post_img']
            target = samples['mask_img']
            disaster_target = samples['disaster']

            loss = model_loss(inputs_pre, inputs_post, target, disaster_target)

            loss_sum = torch.sum(loss).detach().cpu()
            if np.isnan(loss_sum) or np.isinf(loss_sum):
                print('Warning: non-finite loss detected')
            losses.update(loss_sum, 4)  # batch size

            loss = torch.sum(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if args.local_rank == 0 and i % 10 == 0:
                logger.info('epoch: {0}\t'
                            'iter: {1}/{2}\t'
                            'lr: {3:.6f}\t'
                            'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                    epoch + 1, i + 1, len(trainset_loader), lr, loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))
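adjust_learning_rate used above is not shown; a sketch assuming the usual polynomial ('poly') decay over the total number of iterations:

def adjust_learning_rate(optimizer, base_lr, max_iters, cur_iter, power=0.9):
    # Assumed 'poly' schedule: decay the learning rate from base_lr towards 0
    # over max_iters iterations and apply it to every parameter group.
    lr = base_lr * (1 - cur_iter / max_iters) ** power
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr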
Exemple #25
0
def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)
    distributed = len(gpus) > 1
    device = torch.device('cuda:{}'.format(args.local_rank))

    # build model
    model = eval('models.' + config.MODEL.NAME + '.get_seg_model')(config)

    if args.local_rank == 0:
        # provide the summary of model
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]))
        logger.info(get_model_summary(model.to(device), dump_input.to(device)))

        # copy model file
        this_dir = os.path.dirname(__file__)
        models_dst_dir = os.path.join(final_output_dir, 'models')
        if os.path.exists(models_dst_dir):
            shutil.rmtree(models_dst_dir)
        shutil.copytree(os.path.join(this_dir, '../lib/models'),
                        models_dst_dir)

    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
        )

    # prepare data
    crop_size = (config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    train_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TRAIN_SET,
        num_samples=None,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=config.TRAIN.MULTI_SCALE,
        flip=config.TRAIN.FLIP,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TRAIN.BASE_SIZE,
        crop_size=crop_size,
        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
        scale_factor=config.TRAIN.SCALE_FACTOR)

    if distributed:
        train_sampler = DistributedSampler(train_dataset)
    else:
        train_sampler = None

    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE and train_sampler is None,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True,
        sampler=train_sampler)

    if config.DATASET.EXTRA_TRAIN_SET:
        extra_train_dataset = eval('datasets.' + config.DATASET.DATASET)(
            root=config.DATASET.ROOT,
            list_path=config.DATASET.EXTRA_TRAIN_SET,
            num_samples=None,
            num_classes=config.DATASET.NUM_CLASSES,
            multi_scale=config.TRAIN.MULTI_SCALE,
            flip=config.TRAIN.FLIP,
            ignore_label=config.TRAIN.IGNORE_LABEL,
            base_size=config.TRAIN.BASE_SIZE,
            crop_size=crop_size,
            downsample_rate=config.TRAIN.DOWNSAMPLERATE,
            scale_factor=config.TRAIN.SCALE_FACTOR)

        if distributed:
            extra_train_sampler = DistributedSampler(extra_train_dataset)
        else:
            extra_train_sampler = None

        extra_trainloader = torch.utils.data.DataLoader(
            extra_train_dataset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=config.TRAIN.SHUFFLE and extra_train_sampler is None,
            num_workers=config.WORKERS,
            pin_memory=True,
            drop_last=True,
            sampler=extra_train_sampler)

    test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
    test_dataset = eval('datasets.' + config.DATASET.DATASET)(
        root=config.DATASET.ROOT,
        list_path=config.DATASET.TEST_SET,
        num_samples=config.TEST.NUM_SAMPLES,
        num_classes=config.DATASET.NUM_CLASSES,
        multi_scale=False,
        flip=False,
        ignore_label=config.TRAIN.IGNORE_LABEL,
        base_size=config.TEST.BASE_SIZE,
        crop_size=test_size,
        center_crop_test=config.TEST.CENTER_CROP_TEST,
        downsample_rate=1)

    if distributed:
        test_sampler = DistributedSampler(test_dataset)
    else:
        test_sampler = None

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
        sampler=test_sampler)

    # criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

    model = FullModel(model, criterion)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[args.local_rank],
                                                output_device=args.local_rank)

    # optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 model.parameters()),
                'lr': config.TRAIN.LR
            }],
            lr=config.TRAIN.LR,
            momentum=config.TRAIN.MOMENTUM,
            weight_decay=config.TRAIN.WD,
            nesterov=config.TRAIN.NESTEROV,
        )
    else:
        raise ValueError('Only the SGD optimizer is supported')

    epoch_iters = int(len(train_dataset) /
                      config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir, 'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file,
                                    map_location=lambda storage, loc: storage)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters

    for epoch in range(last_epoch, end_epoch):
        if distributed:
            train_sampler.set_epoch(epoch)
        if epoch >= config.TRAIN.END_EPOCH:
            train(config, epoch - config.TRAIN.END_EPOCH,
                  config.TRAIN.EXTRA_EPOCH, epoch_iters, config.TRAIN.EXTRA_LR,
                  extra_iters, extra_trainloader, optimizer, model,
                  writer_dict, device)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, epoch_iters,
                  config.TRAIN.LR, num_iters, trainloader, optimizer, model,
                  writer_dict, device)

        valid_loss, mean_IoU, IoU_array = validate(config, testloader, model,
                                                   writer_dict, device)

        if args.local_rank == 0:
            logger.info('=> saving checkpoint to {}'.format(
                os.path.join(final_output_dir, 'checkpoint.pth.tar')))
            torch.save(
                {
                    'epoch': epoch + 1,
                    'best_mIoU': best_mIoU,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))

            if mean_IoU > best_mIoU:
                best_mIoU = mean_IoU
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'best.pth'))
            msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                valid_loss, mean_IoU, best_mIoU)
            logging.info(msg)
            logging.info(IoU_array)

            if epoch == end_epoch - 1:
                torch.save(model.module.state_dict(),
                           os.path.join(final_output_dir, 'final_state.pth'))

                writer_dict['writer'].close()
                end = timeit.default_timer()
                logger.info('Hours: %d' % int((end - start) / 3600))
                logger.info('Done')
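FullModel above wraps the network together with its criterion so the loss is computed inside the DistributedDataParallel forward pass; a minimal sketch of such a wrapper (an assumption, the project's version may package its return values differently):

import torch.nn as nn


class FullModel(nn.Module):
    # Bundle the segmentation network with its loss so that loss computation
    # happens inside DistributedDataParallel's forward and is balanced per GPU.
    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss

    def forward(self, inputs, labels):
        outputs = self.model(inputs)
        return self.loss(outputs, labels), outputs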
Exemple #26
0
    def train_epoch(self,
                    method,
                    train_dataset,
                    train_collate_fn,
                    batch_size,
                    epoch,
                    optimizer,
                    scheduler=None):
        self.model.train()
        if torch.cuda.is_available():
            sampler = DistributedSampler(train_dataset)
            sampler.set_epoch(epoch)
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                collate_fn=train_collate_fn,
                batch_size=batch_size,
                sampler=sampler,
                pin_memory=True)
        else:
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                collate_fn=train_collate_fn,
                batch_size=batch_size,
                shuffle=True,
                pin_memory=True)

        start_time = time.time()
        count_batch = 0
        bloss = 0
        for j, data in enumerate(train_loader, 0):
            if torch.cuda.is_available():
                data_cuda = dict()
                for key, value in data.items():
                    if isinstance(value, torch.Tensor):
                        data_cuda[key] = value.cuda()
                    else:
                        data_cuda[key] = value
                data = data_cuda
            count_batch += 1

            bloss = self.train_batch(epoch,
                                     data,
                                     method=method,
                                     optimizer=optimizer,
                                     scheduler=scheduler)
            if self.local_rank is None or self.local_rank == 0:
                if j > 0 and j % 100 == 0:
                    elapsed_time = time.time() - start_time
                    if scheduler is not None:
                        print('Method', method, 'Epoch', epoch, 'Batch ',
                              count_batch, 'Loss ', bloss, 'Time ',
                              elapsed_time, 'Learning rate ',
                              scheduler.get_last_lr())
                    else:
                        print('Method', method, 'Epoch', epoch, 'Batch ',
                              count_batch, 'Loss ', bloss, 'Time ',
                              elapsed_time)

                sys.stdout.flush()

        if self.accumulation_count % self.accumulation_steps != 0:
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad()
        elapsed_time = time.time() - start_time
        if self.local_rank is None or self.local_rank == 0:
            if scheduler is not None:
                print('Method', method, 'Epoch', epoch, 'Batch ', count_batch,
                      'Loss ', bloss, 'Time ', elapsed_time, 'Learning rate ',
                      scheduler.get_last_lr())
            else:
                print('Method', method, 'Epoch', epoch, 'Batch ', count_batch,
                      'Loss ', bloss, 'Time ', elapsed_time)
        sys.stdout.flush()
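    # train_epoch above delegates to a train_batch method and relies on gradient
    # accumulation (accumulation_count / accumulation_steps). A sketch of what
    # that method might look like; compute_loss is a hypothetical stand-in for
    # the model's forward/loss computation.
    def train_batch(self, epoch, data, method, optimizer, scheduler=None):
        # Sketch (assumption): accumulate gradients for `accumulation_steps`
        # mini-batches before stepping the optimizer and scheduler.
        loss = self.compute_loss(data, method) / self.accumulation_steps
        loss.backward()
        self.accumulation_count += 1
        if self.accumulation_count % self.accumulation_steps == 0:
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad()
        return loss.item()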
def train_validation_loops(net, logger, args, path_dict, writer, test_mode):

    # %% Load curriculum objects
    path_cur_obj = os.path.join(path_dict['repo_root'], 'cur_objs',
                                args['mode'],
                                'cond_' + args['cur_obj'] + '.pkl')

    with open(path_cur_obj, 'rb') as f:
        train_obj, valid_obj, test_obj = pickle.load(f)

    # %% Specify flags of importance
    train_obj.augFlag = args['aug_flag']
    valid_obj.augFlag = False
    test_obj.augFlag = False

    train_obj.equi_var = args['equi_var']
    valid_obj.equi_var = args['equi_var']
    test_obj.equi_var = args['equi_var']

    # %% Modify path information
    train_obj.path2data = path_dict['path_data']
    valid_obj.path2data = path_dict['path_data']
    test_obj.path2data = path_dict['path_data']

    # %% Create distributed samplers
    train_sampler = DistributedSampler(
        train_obj,
        rank=args['local_rank'],
        shuffle=False,
        num_replicas=args['world_size'],
    )

    valid_sampler = DistributedSampler(
        valid_obj,
        rank=args['local_rank'],
        shuffle=False,
        num_replicas=args['world_size'],
    )

    test_sampler = DistributedSampler(
        test_obj,
        rank=args['local_rank'],
        shuffle=False,
        num_replicas=args['world_size'],
    )

    # %% Define dataloaders
    logger.write('Initializing loaders')
    if not test_mode:
        train_loader = DataLoader(
            train_obj,
            shuffle=False,
            num_workers=args['workers'],
            drop_last=True,
            pin_memory=True,
            batch_size=args['batch_size'],
            sampler=train_sampler if args['do_distributed'] else None,
        )

        valid_loader = DataLoader(
            valid_obj,
            shuffle=False,
            num_workers=args['workers'],
            drop_last=True,
            pin_memory=True,
            batch_size=args['batch_size'],
            sampler=valid_sampler if args['do_distributed'] else None,
        )

    test_loader = DataLoader(
        test_obj,
        shuffle=False,
        num_workers=0,
        drop_last=True,
        batch_size=args['batch_size'],
        sampler=test_sampler if args['do_distributed'] else None,
    )

    # %% Early stopping criterion
    early_stop = EarlyStopping(
        patience=args['early_stop'],
        verbose=True,
        delta=0.001,  # 0.1% improvement needed
        rank_cond=args['local_rank'] == 0 if args['do_distributed'] else True,
        mode='max',
        fName='best_model.pt',
        path_save=path_dict['results'],
    )

    # %% Define alpha and beta scalars
    if args['curr_learn_losses']:
        alpha_scalar = mod_scalar([0, args['epochs']], [0, 1])
        beta_scalar = mod_scalar([10, 20], [0, 1])

    # %% Optimizer
    param_list = [
        param for name, param in net.named_parameters() if 'adv' not in name
    ]
    optimizer = torch.optim.Adam(param_list, lr=args['lr'], amsgrad=False)

    if args['adv_DG']:
        param_list = [
            param for name, param in net.named_parameters() if 'adv' in name
        ]
        optimizer_disc = torch.optim.Adam(param_list,
                                          lr=args['lr'],
                                          amsgrad=True)
    else:
        optimizer_disc = False
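    # Design note: parameters whose names contain 'adv' (presumably the adversarial
    # discriminator branch toggled by args['adv_DG']) get their own Adam optimizer,
    # so the main network and the discriminator can be stepped independently.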

    # %% Loops and what not

    # Create a checkpoint based on current scores
    checkpoint = {}
    checkpoint['args'] = args  # Save arguments

    # Flag the dataset to be re-randomized on the next pass
    # through the main training loop.
    args['time_to_update'] = True

    if test_mode:

        logging.info('Entering test only mode ...')
        args['alpha'] = 0.5
        args['beta'] = 0.5

        test_result = forward(net, [],
                              logger,
                              test_loader,
                              optimizer,
                              args,
                              path_dict,
                              writer=writer,
                              epoch=0,
                              mode='test',
                              batches_per_ep=len(test_loader))

        checkpoint['test_result'] = test_result
        if args['save_results_here']:

            # Ensure the directory exists
            os.makedirs(os.path.dirname(args['save_results_here']),
                        exist_ok=True)

            # Save out test results here instead
            with open(args['save_results_here'], 'wb') as f:
                pickle.dump(checkpoint, f)
        else:

            # Ensure the directory exists
            os.makedirs(path_dict['results'], exist_ok=True)

            # Save out the test results
            with open(path_dict['results'] + '/test_results.pkl', 'wb') as f:
                pickle.dump(checkpoint, f)

    else:
        spiker = SpikeDetection() if args['remove_spikes'] else False
        logging.info('Entering train mode ...')

        epoch = 0

        # Early stopping is intentionally not used to break this loop: training
        # runs for the full epoch budget while early_stop still tracks the best
        # model and intermediate checkpoints are saved.
        # while (epoch < args['epochs']) and not early_stop.early_stop:

        while (epoch < args['epochs']):
            if args['time_to_update']:

                # Toggle flag back to False
                args['time_to_update'] = False

                if args['one_by_one_ds']:
                    train_loader.dataset.sort('one_by_one_ds',
                                              args['batch_size'])
                    valid_loader.dataset.sort('one_by_one_ds',
                                              args['batch_size'])
                else:
                    train_loader.dataset.sort('mutliset_random')
                    valid_loader.dataset.sort('mutliset_random')

            # Set epochs for samplers
            train_sampler.set_epoch(epoch)
            valid_sampler.set_epoch(epoch)

            # %%
            logging.info('Starting epoch: %d' % epoch)

            if args['curr_learn_losses']:
                args['alpha'] = alpha_scalar.get_scalar(epoch)
                args['beta'] = beta_scalar.get_scalar(epoch)
            else:
                args['alpha'] = 0.5
                args['beta'] = 0.5

            if args['dry_run']:
                train_batches_per_ep = len(train_loader)
                valid_batches_per_ep = len(valid_loader)
            else:
                train_batches_per_ep = args['batches_per_ep']
                if args['reduce_valid_samples']:
                    valid_batches_per_ep = args['reduce_valid_samples']
                else:
                    valid_batches_per_ep = len(valid_loader)

            train_result = forward(net,
                                   spiker,
                                   logger,
                                   train_loader,
                                   optimizer,
                                   args,
                                   path_dict,
                                   optimizer_disc=optimizer_disc,
                                   writer=writer,
                                   epoch=epoch,
                                   mode='train',
                                   batches_per_ep=train_batches_per_ep)

            if args['reduce_valid_samples']:
                valid_result = forward(net,
                                       spiker,
                                       logger,
                                       valid_loader,
                                       optimizer,
                                       args,
                                       path_dict,
                                       writer=writer,
                                       epoch=epoch,
                                       mode='valid',
                                       batches_per_ep=valid_batches_per_ep)
            else:
                valid_result = forward(net,
                                       spiker,
                                       logger,
                                       valid_loader,
                                       optimizer,
                                       args,
                                       path_dict,
                                       writer=writer,
                                       epoch=epoch,
                                       mode='valid',
                                       batches_per_ep=len(valid_loader))

            # Update the checkpoint weights. VERY IMPORTANT!
            checkpoint['state_dict'] = move_to_single(net.state_dict())

            checkpoint['epoch'] = epoch
            checkpoint['train_result'] = train_result
            checkpoint['valid_result'] = valid_result

            # Save out the best validation result and model
            early_stop(checkpoint)

            # If epoch is a multiple of args['save_every'], then write out
            if (epoch % args['save_every']) == 0:

                # Ensure that you do not update the validation score at this
                # point and simply save the model
                early_stop.save_checkpoint(
                    checkpoint['valid_result']['score_mean'],
                    checkpoint,
                    update_val_score=False,
                    use_this_name_instead='{}.pt'.format(epoch))

            epoch += 1
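
# The per-rank function above expects args['local_rank'] / args['world_size'] to be
# filled in by an external launcher, which is not shown in this snippet. Below is a
# minimal, self-contained sketch of how such a worker is typically spawned with
# plain PyTorch; the names (_demo_worker, n_procs) are illustrative, not the repo's
# actual entry point.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def _demo_worker(local_rank, world_size):
    # one process per rank; the gloo backend keeps the sketch CPU-only
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=local_rank, world_size=world_size)

    dataset = TensorDataset(torch.arange(32, dtype=torch.float32))
    sampler = DistributedSampler(dataset, num_replicas=world_size,
                                 rank=local_rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # mirrors the set_epoch calls in the loop above
        for (batch,) in loader:
            pass  # a real train step would go here

    dist.destroy_process_group()


if __name__ == '__main__':
    n_procs = 2
    mp.spawn(_demo_worker, args=(n_procs,), nprocs=n_procs)
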
def train(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)
    # disable logging for processes except 0 on every node
    if hvd.local_rank() != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs on the master node for training
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])

    # create a training data loader
    train_ds = Dataset(data=train_files, transform=train_transforms)
    # create a training data sampler
    train_sampler = DistributedSampler(train_ds,
                                       num_replicas=hvd.size(),
                                       rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(
            mp, "_supports_context"
    ) and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=train_sampler,
        multiprocessing_context=multiprocessing_context,
    )

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # Horovod broadcasts parameters & optimizer state
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    # Horovod wraps optimizer with DistributedOptimizer
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # start a typical PyTorch training
    epoch_loss_values = list()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        train_sampler.set_epoch(epoch)
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if hvd.rank() == 0:
        # all processes hold the same parameters: they start from the same
        # random weights and gradients are synchronized on every backward pass,
        # so saving from a single process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
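
# Usage note: a script like the one above is started through Horovod's launcher,
# e.g. `horovodrun -np 4 python train_script.py ...` (one process per GPU);
# hvd.init() then picks up rank/size from the launcher. The exact CLI flags that
# populate args.dir are defined by an argparse setup not shown here.
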
def main():
    global args, best_prec1, min_loss
    args = parser.parse_args()

    rank, world_size = dist_init(args.port)
    print("world_size is: {}".format(world_size))
    assert (args.batch_size % world_size == 0)
    assert (args.workers % world_size == 0)
    args.batch_size = args.batch_size // world_size
    args.workers = args.workers // world_size
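    # Example: with a requested global batch size of 256 and 8 processes, each rank
    # now loads batches of 32 and uses workers // 8 dataloader workers, so the
    # effective global batch size and worker count stay what was requested.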

    # create model
    print("=> creating model '{}'".format("inceptionv4"))
    print("save_path is: {}".format(args.save_path))

    image_size = 341
    input_size = 299
    model = get_model('inceptionv4', pretrained=True)
    # print("model is: {}".format(model))
    model.cuda()
    model = DistModule(model)

    # optionally resume from a checkpoint
    if args.load_path:
        if args.resume_opt:
            best_prec1, start_epoch = load_state(args.load_path,
                                                 model,
                                                 optimizer=optimizer)
        else:
            # print('load weights from', args.load_path)
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_dataset = McDataset(
        args.train_root, args.train_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = McDataset(
        args.val_root, args.val_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
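    # DistributedSampler shuffles by default, so the loaders below keep
    # shuffle=False and rely on train_sampler.set_epoch(...) in the epoch loop
    # to reshuffle with a different seed each epoch.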

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    lr = 0
    patience = 0
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch)
        train_sampler.set_epoch(epoch)

        if epoch == 1:
            lr = 0.00003
        if patience == 2:
            patience = 0
            checkpoint = load_checkpoint(args.save_path + '_best.pth.tar')
            model.load_state_dict(checkpoint['state_dict'])
            print("Loading checkpoint_best.............")
            # model.load_state_dict(torch.load('checkpoint_best.pth.tar'))
            lr = lr / 10.0

        if epoch == 0:
            lr = 0.001
            for name, param in model.named_parameters():
                # print("name is: {}".format(name))
                if (name not in last_layer_names):
                    param.requires_grad = False
            optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                            lr=lr)
            # optimizer = torch.optim.Adam(
            #     filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        else:
            for param in model.parameters():
                param.requires_grad = True
            optimizer = torch.optim.RMSprop(model.parameters(),
                                            lr=lr,
                                            weight_decay=0.0001)
            # optimizer = torch.optim.Adam(
            #     model.parameters(), lr=lr, weight_decay=0.0001)
        print("lr is: {}".format(lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_prec1, val_losses = validate(val_loader, model, criterion)
        print("val_losses is: {}".format(val_losses))
        # remember best prec@1 and save checkpoint
        if rank == 0:
            # remember best prec@1 and save checkpoint
            if val_losses < min_loss:
                is_best = True
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': 'inceptionv4',
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, args.save_path)
                # torch.save(model.state_dict(), 'best_val_weight.pth')
                print(
                    'val score improved from {:.5f} to {:.5f}. Saved!'.format(
                        min_loss, val_losses))

                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        if 1 <= rank <= 7:
            if val_losses < min_loss:
                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        print("patience is: {}".format(patience))
        print("min_loss is: {}".format(min_loss))
    print("min_loss is: {}".format(min_loss))
def main():
    global args, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    for key in config:
        for k, v in config[key].items():
            setattr(args, k, v)

    print('Enabled distributed training.')

    rank, world_size = init_dist(
        backend='nccl', port=args.port)
    args.rank = rank
    args.world_size = world_size


    np.random.seed(args.seed*args.rank)
    torch.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed(args.seed*args.rank)
    torch.cuda.manual_seed_all(args.seed*args.rank)

    # create model
    print("=> creating model '{}'".format(args.model))
    if args.SinglePath:
        architecture = 20*[0]
        channels_scales = 20*[1.0]
        #load derived child network
        log_alpha = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))['state_dict']['log_alpha']
        weights = torch.zeros_like(log_alpha).scatter_(1, torch.argmax(log_alpha, dim = -1).view(-1,1), 1)
        model = ShuffleNetV2_OneShot(args=args, architecture=architecture, channels_scales=channels_scales, weights=weights)
        model.cuda()
        broadcast_params(model)
        for v in model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        model.log_alpha.grad = torch.zeros_like(model.log_alpha)   
        if not args.retrain:
            load_state_ckpt(args.checkpoint_path, model)
            checkpoint = torch.load(args.checkpoint_path, map_location='cuda:{}'.format(torch.cuda.current_device()))
            args.base_lr = checkpoint['optimizer']['param_groups'][0]['lr']
        if args.reset_bn_stat:
            model._reset_bn_running_stats()

    # define loss function (criterion) and optimizer
    criterion = CrossEntropyLoss(smooth_eps=0.1, smooth_dist=(torch.ones(1000)*0.001).cuda()).cuda()
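    # The custom CrossEntropyLoss above applies label smoothing (smooth_eps=0.1 with
    # a near-uniform smoothing distribution over the 1000 ImageNet classes); recent
    # PyTorch versions offer the same idea built in via
    # nn.CrossEntropyLoss(label_smoothing=0.1).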

    wo_wd_params = []
    wo_wd_param_names = []
    network_params = []
    network_param_names = []

    for name, mod in model.named_modules():
        #if isinstance(mod, (nn.BatchNorm2d, SwitchNorm2d)):
        if isinstance(mod, nn.BatchNorm2d):
            for key, value in mod.named_parameters():
                wo_wd_param_names.append(name+'.'+key)
        
    for key, value in model.named_parameters():
        if key != 'log_alpha':
            if value.requires_grad:
                if key in wo_wd_param_names:
                    wo_wd_params.append(value)
                else:
                    network_params.append(value)
                    network_param_names.append(key)

    params = [
        {'params': network_params,
         'lr': args.base_lr,
         'weight_decay': args.weight_decay },
        {'params': wo_wd_params,
         'lr': args.base_lr,
         'weight_decay': 0.},
    ]
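    # Design note: the second parameter group (the BatchNorm affine weights/biases
    # collected above) is trained with weight_decay=0, a common practice since
    # decaying normalization parameters usually brings no benefit.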
    param_names = [network_param_names, wo_wd_param_names]
    if args.rank == 0:
        print('>>> params w/o weight decay: ', wo_wd_param_names)
    optimizer = torch.optim.SGD(params, momentum=args.momentum)
    arch_optimizer=None

    # auto resume from a checkpoint
    remark = 'imagenet_'
    remark += 'epo_' + str(args.epochs) + '_layer_' + str(args.layers) + '_batch_' + str(args.batch_size) + '_lr_' + str(float("{0:.2f}".format(args.base_lr))) + '_seed_' + str(args.seed)

    if args.remark != 'none':
        remark += '_'+args.remark

    args.save = 'search-{}-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"), remark)
    args.save_log = 'nas-{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), remark)
    generate_date = str(datetime.now().date())

    path = os.path.join(generate_date, args.save)
    if args.rank == 0:
        log_format = '%(asctime)s %(message)s'
        utils.create_exp_dir(generate_date, path, scripts_to_save=glob.glob('*.py'))
        logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                            format=log_format, datefmt='%m/%d %I:%M:%S %p')
        fh = logging.FileHandler(os.path.join(path, 'log.txt'))
        fh.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(fh)
        logging.info("args = %s", args)
        writer = SummaryWriter('./runs/' + generate_date + '/' + args.save_log)
    else:
        writer = None

    #model_dir = args.model_dir
    model_dir = path
    start_epoch = 0
    
    if args.evaluate:
        load_state_ckpt(args.checkpoint_path, model)
    else:
        best_prec1, start_epoch = load_state(model_dir, model, optimizer=optimizer)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_wo_ms = ImagenetDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = ImagenetDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    train_loader_wo_ms = DataLoader(
        train_dataset_wo_ms, batch_size=args.batch_size//args.world_size, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=train_sampler)

    val_loader = DataLoader(
        val_dataset, batch_size=50, shuffle=False,
        num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion, 0, writer, logging)
        return

    niters = len(train_loader)

    lr_scheduler = LRScheduler(optimizer, niters, args)

    for epoch in range(start_epoch, args.epochs):
        train_sampler.set_epoch(epoch)
        
        if args.rank == 0 and args.SinglePath:
            logging.info('epoch %d', epoch)
        
        # evaluate on validation set after loading the model
        if epoch == 0 and not args.reset_bn_stat:
            prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        # train for one epoch
        if epoch >= args.epochs - 5 and args.lr_mode == 'step' and args.off_ms and args.retrain:
            train(train_loader_wo_ms, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)
        else:
            train(train_loader, model, criterion, optimizer, arch_optimizer, lr_scheduler, epoch, writer, logging)


        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer, logging)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(model_dir, {
                'epoch': epoch + 1,
                'model': args.model,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)