Example #1
0
def main():
    global args, rank, world_size
    if args.dist == 1:
        rank, world_size = dist_init()
    else:
        rank = 0
        world_size = 1

    DATA_DIR = './data'

    train_set_raw = torchvision.datasets.CIFAR10(root=DATA_DIR,
                                                 train=True,
                                                 download=True)
    test_set_raw = torchvision.datasets.CIFAR10(root=DATA_DIR,
                                                train=False,
                                                download=True)

    lr_schedule = PiecewiseLinear([0, 5, 24], [0, 0.4 * args.lr_scale, 0])
    train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]

    model = TorchGraph(union(net(), losses)).cuda()
    if args.half == 1:
        model = model.half()
    if args.double == 1:
        model = model.double()
    if args.dist == 1:
        model = DistModule(model)
    opt = torch.optim.SGD(model.parameters(),
                          lr=0.0,
                          momentum=args.momentum,
                          weight_decay=5e-4 * args.batch_size,
                          nesterov=True)

    t = Timer()

    train_set = list(
        zip(transpose(normalise(pad(train_set_raw.data, 4))),
            train_set_raw.targets))
    test_set = list(
        zip(transpose(normalise(test_set_raw.data)), test_set_raw.targets))
    dataset_len = len(train_set)
    args.warm_up_iter = math.ceil(dataset_len * args.warm_up_epoch /
                                  (world_size * args.batch_size))

    TSV = TSVLogger()
    train(model,
          lr_schedule,
          opt,
          Transform(train_set, train_transforms),
          test_set,
          args=args,
          batch_size=args.batch_size,
          num_workers=args.workers,
          loggers=(TableLogger(rank), TSV),
          timer=t,
          test_time_in_total=False,
          drop_last=True)
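# The schedule above, PiecewiseLinear([0, 5, 24], [0, 0.4 * args.lr_scale, 0]),
# suggests linear interpolation of the learning rate between (epoch, value)
# knots. A minimal sketch of that idea, assuming plain linear interpolation
# (the repo's own PiecewiseLinear class may differ in detail):
import numpy as np

def piecewise_linear_lr(epoch, knots, values):
    """Interpolate a learning rate for a (possibly fractional) epoch."""
    return float(np.interp(epoch, knots, values))

# piecewise_linear_lr(2.5, [0, 5, 24], [0.0, 0.4, 0.0]) -> 0.2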
def main():
    global args, best_prec1, min_loss
    args = parser.parse_args()

    rank, world_size = dist_init(args.port)
    print("world_size is: {}".format(world_size))
    assert (args.batch_size % world_size == 0)
    assert (args.workers % world_size == 0)
    args.batch_size = args.batch_size // world_size
    args.workers = args.workers // world_size

    # create model
    print("=> creating model '{}'".format("inceptionv4"))
    print("save_path is: {}".format(args.save_path))

    image_size = 341
    input_size = 299
    model = get_model('inceptionv4', pretrained=True)
    # print("model is: {}".format(model))
    model.cuda()
    model = DistModule(model)

    # optionally resume from a checkpoint
    if args.load_path:
        if args.resume_opt:
            # NOTE: the optimizer is only created inside the epoch loop below,
            # so its state cannot be restored here; resume weights and
            # bookkeeping only
            best_prec1, start_epoch = load_state(args.load_path, model)
        else:
            # print('load weights from', args.load_path)
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_dataset = McDataset(
        args.train_root, args.train_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.RandomCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ColorAugmentation(),
            normalize,
        ]))
    val_dataset = McDataset(
        args.val_root, args.val_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    lr = 0
    patience = 0
    for epoch in range(args.start_epoch, args.epochs):
        # adjust_learning_rate(optimizer, epoch)
        train_sampler.set_epoch(epoch)

        if epoch == 1:
            lr = 0.00003
        if patience == 2:
            patience = 0
            checkpoint = load_checkpoint(args.save_path + '_best.pth.tar')
            model.load_state_dict(checkpoint['state_dict'])
            print("Loading checkpoint_best.............")
            # model.load_state_dict(torch.load('checkpoint_best.pth.tar'))
            lr = lr / 10.0

        if epoch == 0:
            lr = 0.001
            for name, param in model.named_parameters():
                # print("name is: {}".format(name))
                if (name not in last_layer_names):
                    param.requires_grad = False
            optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                            lr=lr)
            # optimizer = torch.optim.Adam(
            #     filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        else:
            for param in model.parameters():
                param.requires_grad = True
            optimizer = torch.optim.RMSprop(model.parameters(),
                                            lr=lr,
                                            weight_decay=0.0001)
            # optimizer = torch.optim.Adam(
            #     model.parameters(), lr=lr, weight_decay=0.0001)
        print("lr is: {}".format(lr))
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_prec1, val_losses = validate(val_loader, model, criterion)
        print("val_losses is: {}".format(val_losses))
        # remember best prec@1 and save checkpoint
        if rank == 0:
            # remember best prec@1 and save checkpoint
            if val_losses < min_loss:
                is_best = True
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': 'inceptionv4',
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, args.save_path)
                # torch.save(model.state_dict(), 'best_val_weight.pth')
                print(
                    'val score improved from {:.5f} to {:.5f}. Saved!'.format(
                        min_loss, val_losses))

                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        if 1 <= rank <= 7:
            if val_losses < min_loss:
                min_loss = val_losses
                patience = 0
            else:
                patience += 1
        print("patience is: {}".format(patience))
        print("min_loss is: {}".format(min_loss))
    print("min_loss is: {}".format(min_loss))
def main():
    global args
    args = parser.parse_args()

    # TODO model arguments module should be more easy to write and read
    if args.approach == 'lwf':
        approach = lwf
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'joint_train':
        approach = joint_train
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'fine_tuning':
        approach = fine_tuning
        assert (args.memory_size is None)
        assert (args.memory_mini_batch_size is None)
    elif args.approach == 'gem':
        approach = gem
        assert (args.memory_size is not None)
        assert (args.memory_mini_batch_size is None)
    else:
        approach = None

    rank, world_size = dist_init('27777')

    if rank == 0:
        print('=' * 100)
        print('Arguments = ')
        for arg in vars(args):
            print('\t' + arg + ':', getattr(args, arg))
        print('=' * 100)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    else:
        print('[CUDA unavailable]')
        sys.exit()

    # Generate Tasks
    args.batch_size = args.batch_size // world_size
    Tasks = generator.GetTasks(args.approach,
                               args.batch_size,
                               world_size,
                               memory_size=args.memory_size,
                               memory_mini_batch_size=args.memory_mini_batch_size)
    # Network
    net = network.resnet50(pretrained=True).cuda()
    net = DistModule(net)
    # Approach
    Appr = approach.Approach(net, args, Tasks)

    # Solve tasks incrementally
    for t in range(len(Tasks)):
        task = Tasks[t]

        if rank == 0:
            print('*' * 100)
            print()
            print('Task {:d}: {:d} classes ({:s})'.format(
                t, task['class_num'], task['description']))
            print()
            print('*' * 100)

        Appr.solve(t, Tasks)

        if rank == 0:
            print('*' * 100)
            print('Task {:d}: {:d} classes Finished.'.format(
                t, task['class_num']))
            print('*' * 100)
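# dist_init and DistModule appear to be the repo's own distributed helpers.
# A rough sketch of the analogous setup in stock PyTorch, assuming the
# launcher sets the usual RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
# environment variables:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def plain_dist_init(local_rank):
    dist.init_process_group(backend='nccl')   # env:// rendezvous by default
    torch.cuda.set_device(local_rank)
    return dist.get_rank(), dist.get_world_size()

# rank, world_size = plain_dist_init(local_rank)
# net = torchvision.models.resnet50(pretrained=True).cuda()
# net = DDP(net, device_ids=[local_rank])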
Example #4
0
def main():
    global args, config, best_prec1
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    config = EasyDict(config['common'])
    config.save_path = os.path.dirname(args.config)

    rank, world_size = dist_init()

    # create model
    bn_group_size = config.model.kwargs.bn_group_size
    bn_var_mode = config.model.kwargs.get('bn_var_mode', 'L2')
    if bn_group_size == 1:
        bn_group = None
    else:
        assert world_size % bn_group_size == 0
        bn_group = simple_group_split(world_size, rank,
                                      world_size // bn_group_size)

    config.model.kwargs.bn_group = bn_group
    config.model.kwargs.bn_var_mode = (link.syncbnVarMode_t.L1 if bn_var_mode
                                       == 'L1' else link.syncbnVarMode_t.L2)
    model = model_entry(config.model)
    if rank == 0:
        print(model)

    model.cuda()

    if config.optimizer.type == 'FP16SGD' or config.optimizer.type == 'FusedFP16SGD':
        args.fp16 = True
    else:
        args.fp16 = False

    if args.fp16:
        # If a module must keep fp32 parameters and also needs fp32 inputs,
        # register it with link.fp16.register_float_module(your_module).
        # If only fp32 parameters are needed, pass cast_args=False to that
        # call and then call link.fp16.init() before model.half().
        if config.optimizer.get('fp16_normal_bn', False):
            print('using normal bn for fp16')
            link.fp16.register_float_module(link.nn.SyncBatchNorm2d,
                                            cast_args=False)
            link.fp16.register_float_module(torch.nn.BatchNorm2d,
                                            cast_args=False)
            link.fp16.init()
        model.half()

    model = DistModule(model, args.sync)

    # create optimizer
    opt_config = config.optimizer
    opt_config.kwargs.lr = config.lr_scheduler.base_lr
    if config.get('no_wd', False):
        param_group, type2num = param_group_no_wd(model)
        opt_config.kwargs.params = param_group
    else:
        opt_config.kwargs.params = model.parameters()

    optimizer = optim_entry(opt_config)

    # optionally resume from a checkpoint
    last_iter = -1
    best_prec1 = 0
    if args.load_path:
        if args.recover:
            best_prec1, last_iter = load_state(args.load_path,
                                               model,
                                               optimizer=optimizer)
        else:
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # augmentation
    aug = [
        transforms.RandomResizedCrop(config.augmentation.input_size),
        transforms.RandomHorizontalFlip()
    ]

    for k in config.augmentation.keys():
        assert k in [
            'input_size', 'test_resize', 'rotation', 'colorjitter', 'colorold'
        ]
    rotation = config.augmentation.get('rotation', 0)
    colorjitter = config.augmentation.get('colorjitter', None)
    colorold = config.augmentation.get('colorold', False)

    if rotation > 0:
        aug.append(transforms.RandomRotation(rotation))

    if colorjitter is not None:
        aug.append(transforms.ColorJitter(*colorjitter))

    aug.append(transforms.ToTensor())

    if colorold:
        aug.append(ColorAugmentation())

    aug.append(normalize)

    # train
    train_dataset = McDataset(config.train_root,
                              config.train_source,
                              transforms.Compose(aug),
                              fake=args.fake)

    # val
    val_dataset = McDataset(
        config.val_root, config.val_source,
        transforms.Compose([
            transforms.Resize(config.augmentation.test_resize),
            transforms.CenterCrop(config.augmentation.input_size),
            transforms.ToTensor(),
            normalize,
        ]), args.fake)

    train_sampler = DistributedGivenIterationSampler(
        train_dataset,
        config.lr_scheduler.max_iter,
        config.batch_size,
        last_iter=last_iter)
    val_sampler = DistributedSampler(val_dataset, round_up=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.workers,
                            pin_memory=True,
                            sampler=val_sampler)

    config.lr_scheduler['optimizer'] = optimizer.optimizer if isinstance(
        optimizer, FP16SGD) else optimizer
    config.lr_scheduler['last_iter'] = last_iter
    lr_scheduler = get_scheduler(config.lr_scheduler)

    if rank == 0:
        tb_logger = SummaryWriter(config.save_path + '/events')
        logger = create_logger('global_logger', config.save_path + '/log.txt')
        logger.info('args: {}'.format(pprint.pformat(args)))
        logger.info('config: {}'.format(pprint.pformat(config)))
    else:
        tb_logger = None

    if args.evaluate:
        if args.fusion_list is not None:
            validate(val_loader,
                     model,
                     fusion_list=args.fusion_list,
                     fuse_prob=args.fuse_prob)
        else:
            validate(val_loader, model)
        link.finalize()
        return

    train(train_loader, val_loader, model, optimizer, lr_scheduler,
          last_iter + 1, tb_logger)

    link.finalize()
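# link.fp16.register_float_module above keeps the registered layers in fp32
# while the rest of the network runs in half precision. A minimal sketch of
# the same "fp16 everywhere except BatchNorm" idea in plain PyTorch, assuming
# a CUDA/cudnn build that accepts fp16 inputs to fp32 BatchNorm layers:
import torch.nn as nn

def half_except_bn(model):
    model.half()
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.float()   # keep BN statistics and affine parameters in fp32
    return model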
def main():
    global args, rank, world_size, best_prec1, dataset_len

    if args.dist == 1:
        rank, world_size = dist_init()
    else:
        rank = 0
        world_size = 1

    model = LeNet()
    model.cuda()
    if args.double == 1:
        param_copy = [
            param.clone().type(torch.cuda.DoubleTensor).detach()
            for param in model.parameters()
        ]
    else:
        param_copy = [
            param.clone().type(torch.cuda.FloatTensor).detach()
            for param in model.parameters()
        ]

    for param in param_copy:
        param.requires_grad = True

    if args.double == 1:
        model = model.double()
    if args.half == 1:
        model = model.half()
    if args.dist == 1:
        model = DistModule(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(param_copy,
                                args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    last_iter = -1

    # Data loading code
    train_dataset = datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=False)
    val_dataset = datasets.MNIST(root='./data',
                                 train=False,
                                 transform=transforms.ToTensor(),
                                 download=False)

    dataset_len = len(train_dataset)
    args.max_iter = math.ceil(
        (dataset_len * args.epoch) / (world_size * args.batch_size))

    if args.dist == 1:
        train_sampler = DistributedGivenIterationSampler(train_dataset,
                                                         args.max_iter,
                                                         args.batch_size,
                                                         last_iter=last_iter)
        val_sampler = DistributedSampler(val_dataset, round_up=False)
    else:
        train_sampler = DistributedGivenIterationSampler(train_dataset,
                                                         args.max_iter,
                                                         args.batch_size,
                                                         world_size=1,
                                                         rank=0,
                                                         last_iter=last_iter)
        val_sampler = None

    # pin_memory=True copies batches into page-locked (pinned) host memory,
    # which speeds up host-to-GPU transfer
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            sampler=val_sampler)

    train(train_loader, val_loader, model, criterion, optimizer, param_copy)
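# The example above keeps fp32 (or fp64) master copies of the parameters in
# param_copy and hands those to SGD while the model itself may run in half
# precision. A minimal sketch of the per-step synchronisation this pattern
# implies, assuming `model_params` are the (possibly fp16) model parameters
# and `master_params` the fp32 copies given to the optimizer:
import torch

def master_weight_step(model_params, master_params, optimizer):
    # copy the model gradients onto the fp32 master copies
    for p, m in zip(model_params, master_params):
        m.grad = p.grad.detach().float()
    optimizer.step()
    # write the updated fp32 weights back into the model
    with torch.no_grad():
        for p, m in zip(model_params, master_params):
            p.copy_(m)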
Example #6
0
def main():
    global args, best_prec1, timer
    args = parser.parse_args()
    rank, world_size = dist_init(args.port)
    assert (args.batch_size % world_size == 0)
    assert (args.workers % world_size == 0)
    args.batch_size = args.batch_size // world_size
    args.workers = args.workers // world_size

    # step1: create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('inception_v3'):
        print('inception_v3 without aux_logits!')
        image_size = 341
        input_size = 299
        model = models.__dict__[args.arch](aux_logits=False)
    elif args.arch.startswith('ir18'):
        image_size = 640
        input_size = 448
        model = IR18()
    else:
        image_size = 256
        input_size = 224
        model = models.__dict__[args.arch]()

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        if os.path.isfile(args.pretrained):
            print("=> loading pretrained_model '{}'".format(args.pretrained))
            pretrained_model = torch.load(args.pretrained)
            model.load_state_dict(pretrained_model['state_dict'], strict=False)
            print("=> loaded pretrained_model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))
    model.cuda()
    model = DistModule(model)

    # step2: define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # step3: Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = McDataset(
        args.train_root,
        args.train_source,
        transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ColorAugmentation(),
            # normalize,
        ]))
    val_dataset = McDataset(
        args.val_root,
        args.val_source,
        transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            # normalize,
        ]))

    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=False,
                              sampler=train_sampler)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=False,
                            sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    timer = Timer(
        len(train_loader) + len(val_loader), args.epochs - args.start_epoch)
    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        if rank == 0:
            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, args.save_path)
            print('* Best Prec 1: {best:.3f}'.format(best=best_prec1))
            f11.close()
            acclist.append(0)
            return 0, current_ep
        test_acc, best_acc = test(ep)
        logger.debug(test_acc)
        if early_stop.step(test_acc):
            break
    result = [best_acc, bs_explore, str(lr_explore)[0:7]]  # avoid shadowing built-in `list`
    reslist.append(result)
    acclist.append(best_acc)
    return best_acc, current_ep


if __name__ == "__main__":
    args = get_args()
    rank, world_size = dist_init(args.port)
    if rank == 1:
        f11 = open('/root/rank' + str(rank), 'a+')
        f11.write('rank:' + str(rank) + "\n")
        f11.write("world_size:" + str(world_size) + "\n")
        f11.close()
    example_start_time = time.time()
    try:
        real_model_file = os.path.join("/root", "real_model.json")
        experiment_path = os.environ[
            "HOME"] + "/mountdir/nni/experiments/" + str(
                nni.get_experiment_id())
        assert (args.workers % world_size == 0)
        args.workers = args.workers // world_size
        #real_model_file = os.path.join(trial_env_vars.NNI_SYS_DIR, "real_model.json")
        if rank == 0:  # only works for single node
Example #8
0
def main():
    global args, config, best_loss
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)

    for k, v in config['common'].items():
        setattr(args, k, v)
    config = EasyDict(config['common'])

    rank, world_size, device_id = dist_init(
        os.path.join(args.distributed_path, config.distributed_file))

    args.save_path_dated = args.save_path + '/' + args.datetime
    if args.run_tag != '':
        args.save_path_dated += '-' + args.run_tag

    # create model
    model = model_entry(config.model)
    model.cuda()

    model = nn.parallel.DistributedDataParallel(model, device_ids=[device_id])

    # create optimizer
    opt_config = config.optimizer
    opt_config.kwargs.lr = config.lr_scheduler.base_lr
    opt_config.kwargs.params = model.parameters()

    optimizer = optim_entry(opt_config)

    # optionally resume from a checkpoint
    last_iter = -1
    best_loss = 1e9
    if args.load_path:
        if args.recover:
            best_loss, last_iter = load_state(args.load_path,
                                              model,
                                              optimizer=optimizer)
        else:
            load_state(args.load_path, model)

    cudnn.benchmark = True

    # train augmentation
    if config.augmentation.get('imgnet_mean', False):
        model_mean = (0.485, 0.456, 0.406)
        model_std = (0.229, 0.224, 0.225)
    else:
        model_mean = (0.5, 0.5, 0.5)
        model_std = (0.5, 0.5, 0.5)
    trans = albumentations.Compose([
        RandomResizedCrop(config.augmentation.input_size,
                          config.augmentation.input_size,
                          scale=(config.augmentation.min_scale**2., 1.),
                          ratio=(1., 1.)),
        HorizontalFlip(p=0.5),
        RandomBrightnessContrast(brightness_limit=0.25,
                                 contrast_limit=0.1,
                                 p=0.5),
        JpegCompression(p=.2, quality_lower=50),
        MotionBlur(p=0.5),
        Normalize(mean=model_mean, std=model_std),
        ToTensorV2()
    ])

    train_dataset = FaceDataset(config.train_root,
                                config.train_source,
                                transform=trans,
                                resize=config.augmentation.input_size,
                                image_format=config.get('image_format', None),
                                random_frame=config.get(
                                    'train_random_frame', False),
                                bgr=config.augmentation.get('bgr', False))

    train_sampler = DistributedGivenIterationSampler(
        train_dataset,
        config.lr_scheduler.max_iter,
        config.batch_size,
        last_iter=last_iter)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.workers,
                              pin_memory=True,
                              sampler=train_sampler)

    # validation augmentation
    trans = albumentations.Compose([
        Resize(config.augmentation.input_size, config.augmentation.input_size),
        Normalize(mean=model_mean, std=model_std),
        ToTensorV2()
    ])
    val_multi_loader = []
    if args.val_source != '':
        for dataset_idx in range(len(args.val_source)):
            val_dataset = FaceDataset(
                args.val_root[dataset_idx],
                args.val_source[dataset_idx],
                transform=trans,
                output_index=True,
                resize=config.augmentation.input_size,
                image_format=config.get('image_format', None),
                bgr=config.augmentation.get('bgr', False))
            val_sampler = DistributedSampler(val_dataset, round_up=False)
            val_loader = DataLoader(val_dataset,
                                    batch_size=config.batch_size,
                                    shuffle=False,
                                    num_workers=config.workers,
                                    pin_memory=True,
                                    sampler=val_sampler)
            val_multi_loader.append(val_loader)

    config.lr_scheduler['optimizer'] = optimizer
    config.lr_scheduler['last_iter'] = last_iter
    lr_scheduler = get_scheduler(config.lr_scheduler)

    if rank == 0:
        mkdir(args.save_path)

        mkdir(args.save_path_dated)
        tb_logger = SummaryWriter(args.save_path_dated)

        logger = create_logger('global_logger',
                               args.save_path_dated + '-log.txt')
        logger.info('{}'.format(args))
        logger.info(model)
        logger.info(parameters_string(model))
        logger.info('len(train dataset) = %d' % len(train_loader.dataset))
        for dataset_idx in range(len(val_multi_loader)):
            logger.info(
                'len(val%d dataset) = %d' %
                (dataset_idx, len(val_multi_loader[dataset_idx].dataset)))

        mkdir(args.save_path_dated + '/saves')
    else:
        tb_logger = None

    positive_weight = config.get('positive_weight', 0.5)
    weight = torch.tensor([1. - positive_weight, positive_weight]) * 2.
    if rank == 0:
        logger.info('using class weights: {}'.format(weight.tolist()))

    criterion = nn.CrossEntropyLoss(weight=weight).cuda()

    if args.evaluate:
        if args.evaluate_path:
            all_ckpt = get_all_checkpoint(args.evaluate_path, args.range_list,
                                          rank)

            for ckpt in all_ckpt:
                if rank == 0:
                    logger.info('Testing ckpt: ' + ckpt)
                last_iter = -1
                _, last_iter = load_state(ckpt, model, optimizer=optimizer)
                for dataset_idx in range(len(val_multi_loader)):
                    validate(dataset_idx,
                             val_multi_loader[dataset_idx],
                             model,
                             criterion,
                             tb_logger,
                             curr_step=last_iter,
                             save_softmax=True)
        else:
            for dataset_idx in range(len(val_multi_loader)):
                validate(dataset_idx,
                         val_multi_loader[dataset_idx],
                         model,
                         criterion,
                         tb_logger,
                         curr_step=last_iter,
                         save_softmax=True)

        return

    train(train_loader, val_multi_loader, model, criterion, optimizer,
          lr_scheduler, last_iter + 1, tb_logger)
    return
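# FaceDataset is the repo's own dataset class. A minimal sketch of how an
# albumentations pipeline such as `trans` above is typically applied inside
# a Dataset; `load_rgb` is a hypothetical loader returning an HxWxC uint8
# numpy array:
import numpy as np
from torch.utils.data import Dataset

class ToyFaceDataset(Dataset):
    def __init__(self, image_paths, labels, transform):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = np.asarray(load_rgb(self.image_paths[idx]))  # hypothetical loader
        tensor = self.transform(image=image)['image']        # albumentations convention
        return tensor, self.labels[idx]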