Example #1
def main():
    # parse arg and start experiment
    global args
    best_acc = 0.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']

    # limit which GPUs are visible
    # WARNING: this must be set at the very beginning, before any CUDA call,
    # so that every part of the program sees the same device assignment
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    # after masking, the visible devices are re-indexed from 0
    gpu_id = list(range(len(args.gpu_id.split(','))))

    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            old_args = checkpoint['args']
            print('Old args:')
            print(old_args)
            # set args based on checkpoint
            if args.start_epoch <= 0:
                args.start_epoch = checkpoint['epoch'] + 1
            best_epoch = args.start_epoch - 1
            print('Epoch recovered: %d' % checkpoint['epoch'])
            best_acc = checkpoint['best_acc']
            for name in arch_resume_names:
                if name in vars(args) and name in vars(old_args):
                    setattr(args, name, getattr(old_args, name))

            model = getModel(**vars(args))

            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'"
                  .format(args.resume))
        else:
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
    elif args.pretrain:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))
        model = load_pretrained_diff_parameter(model, args.pretrain)

        print("=> pre-train weights loaded")
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    model = torch.nn.DataParallel(model, device_ids=gpu_id).cuda()

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer
    optimizer = get_optimizer(model, args)

    # set random seed
    torch.manual_seed(args.seed)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader
    if args.eval == 'train':
        train_loader, _, _ = getDataloaders(args.data,
                                            splits=('train',), batch_size=args.batch_size)
        trainer.test(train_loader, best_epoch)
        return
    elif args.eval == 'val':
        _, val_loader, _ = getDataloaders(args.data,
                                          splits=('val',), batch_size=args.batch_size)
        trainer.test(val_loader, best_epoch)
        return
    elif args.eval == 'test':
        _, _, test_loader = getDataloaders(args.data, splits=('test',), batch_size=args.batch_size)
        trainer.test(test_loader, best_epoch)
        return
    else:
        train_loader, val_loader, _ = getDataloaders(args.data,
                                                     splits=('train', 'val'),
                                                     batch_size=args.batch_size)

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*args):
        print(*args)
        print(*args, file=f_log)

    log_print('args:')
    log_print(args)
    # print('model:', file=f_log)
    # print(model, file=f_log)
    f_log.flush()
    log_print('# of params:',
              str(sum([p.numel() for p in model.parameters()])))

    torch.save(args, os.path.join(args.save, 'args.pth'))
    if args.tensorboard:
        configure(args.save, flush_secs=5)

    for epoch in range(args.start_epoch, args.epochs + 1):

        # train for one epoch
        train_loss, train_acc, lr = trainer.train(
            train_loader, epoch)

        if args.tensorboard:
            log_value('lr', lr, epoch)
            log_value('train_loss', train_loss, epoch)
            log_value('train_acc', train_acc, epoch)

        # evaluate on validation set
        val_loss, val_acc, recall, precision, f1, acc = trainer.test(val_loader, epoch, silence=True)

        if args.tensorboard:
            log_value('val_loss', val_loss, epoch)
            log_value('val_acc', val_acc, epoch)
            # log recall, precision, F1 and accuracy for every class;
            # labels are assumed to be consecutive integers 0, 1, 2, ...
            for i in range(args.num_classes):
                try:
                    log_value('cls_' + str(i) + '_recall', recall[i], epoch)
                except Exception:
                    log_value('cls_' + str(i) + '_recall', 0, epoch)
                try:
                    log_value('cls_' + str(i) + '_precision', precision[i], epoch)
                except Exception:
                    log_value('cls_' + str(i) + '_precision', 0, epoch)
                try:
                    log_value('cls_' + str(i) + '_f1', f1[i], epoch)
                except Exception:
                    log_value('cls_' + str(i) + '_f1', 0, epoch)
                try:
                    log_value('cls_' + str(i) + '_acc', acc[i], epoch)
                except Exception:
                    log_value('cls_' + str(i) + '_acc', 0, epoch)

        # write this epoch's scores to the log file
        print(('epoch:{}\tlr:{}\ttrain_loss:{:.4f}\ttrain_acc:{:.4f}\tval_loss:{:.4f}\tval_acc:{:.4f}')
              .format(epoch, lr, train_loss, train_acc, val_loss, val_acc), file=f_log)
        for i in range(args.num_classes):
            try:
                print(('cls_{}_recall: {:.4f}').format(i, recall[i]), file=f_log)
            except Exception:
                print(('cls_{}_recall: {:.4f}').format(i, 0), file=f_log)
            try:
                print(('cls_{}_precision: {:.4f}').format(i, precision[i]), file=f_log)
            except Exception:
                print(('cls_{}_precision: {:.4f}').format(i, 0), file=f_log)
            try:
                print(('cls_{}_f1: {:.4f}').format(i, f1[i]), file=f_log)
            except Exception:
                print(('cls_{}_f1: {:.4f}').format(i, 0), file=f_log)
            try:
                print(('cls_{}_acc: {:.4f}').format(i, acc[i]), file=f_log)
            except Exception:
                print(('cls_{}_acc: {:.4f}').format(i, 0), file=f_log)
        f_log.flush()

        # remember best val_acc and save checkpoint
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
            best_epoch = epoch
            print(Fore.GREEN + 'Best val_acc {}'.format(best_acc) + Fore.RESET, file=f_log)
        f_log.flush()

        checkpoint_dict = {
            'args': args,
            'epoch': epoch,
            'best_epoch': best_epoch,
            'arch': args.arch,
            'state_dict': model.module.state_dict(),
            'best_acc': best_acc,
        }
        # state_dict: saving model.state_dict() would prefix every key with "module.",
        # and such a checkpoint could only be resumed after wrapping the model in
        # DataParallel; saving model.module.state_dict() avoids that.
        save_checkpoint(checkpoint_dict, is_best, args.save, filename='checkpoint_' + str(epoch) + '.pth.tar')
        if not is_best and epoch - best_epoch >= args.patience > 0:
            break
    print('Best val_acc: {:.4f} at epoch {}'.format(best_acc, best_epoch), file=f_log)
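
As the checkpoint comment in Example #1 notes, torch.nn.DataParallel prefixes every parameter name with "module.", so a checkpoint saved from the wrapped model cannot be loaded directly into a plain model. Below is a minimal sketch, not part of the original code, of one common way to handle that by stripping the prefix while loading; the function name and the 'state_dict' key mirror the examples above.

import torch

def load_into_unwrapped_model(model, checkpoint_path):
    # Sketch only: strip the 'module.' prefix that nn.DataParallel adds,
    # assuming the checkpoint stores its weights under 'state_dict'.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    cleaned = {k[len('module.'):] if k.startswith('module.') else k: v
               for k, v in state_dict.items()}
    model.load_state_dict(cleaned)
    return model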
Example #2
def train_model():
    # parse arg and start experiment
    global args
    best_err1 = 100.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            old_args = checkpoint['args']
            print('Old args:')
            print(old_args)
            # set args based on checkpoint
            if args.start_epoch <= 0:
                args.start_epoch = checkpoint['epoch'] + 1
            best_epoch = args.start_epoch - 1
            best_err1 = checkpoint['best_err1']
            for name in arch_resume_names:
                if name in vars(args) and name in vars(old_args):
                    setattr(args, name, getattr(old_args, name))
            model = getModel(**vars(args))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer
    optimizer = get_optimizer(model, args)

    # set random seed
    torch.manual_seed(args.seed)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader
    if args.evaluate == 'train':
        train_loader, _, _ = getDataloaders(
            splits=('train',), **vars(args))
        trainer.test(train_loader, best_epoch)
        return
    elif args.evaluate == 'val':
        _, val_loader, _ = getDataloaders(
            splits=('val',), **vars(args))
        trainer.test(val_loader, best_epoch)
        return
    elif args.evaluate == 'test':
        _, _, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        trainer.test(test_loader, best_epoch)
        return
    else:
        train_loader, val_loader, _ = getDataloaders(
            splits=('train', 'val'), **vars(args))

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*args):
        print(*args)
        print(*args, file=f_log)
    log_print('args:')
    log_print(args)
    print('model:', file=f_log)
    print(model, file=f_log)
    log_print('# of params:',
              str(sum([p.numel() for p in model.parameters()])))
    f_log.flush()
    torch.save(args, os.path.join(args.save, 'args.pth'))
    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_err1'
              '\tval_err1\ttrain_err5\tval_err5']
    if args.tensorboard:
        configure(args.save, flush_secs=5)

    for epoch in range(args.start_epoch, args.epochs + 1):

        # train for one epoch
        train_loss, train_err1, train_err5, lr = trainer.train(
            train_loader, epoch)

        if args.tensorboard:
            log_value('lr', lr, epoch)
            log_value('train_loss', train_loss, epoch)
            log_value('train_err1', train_err1, epoch)
            log_value('train_err5', train_err5, epoch)

        # evaluate on validation set
        val_loss, val_err1, val_err5 = trainer.test(val_loader, epoch)

        if args.tensorboard:
            log_value('val_loss', val_loss, epoch)
            log_value('val_err1', val_err1, epoch)
            log_value('val_err5', val_err5, epoch)

        # save scores to a tsv file, rewrite the whole file to prevent
        # accidental deletion
        scores.append(('{}\t{}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_err1, val_err1, train_err5, val_err5))
        with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
            print('\n'.join(scores), file=f)

        # remember best err@1 and save checkpoint
        is_best = val_err1 < best_err1
        if is_best:
            best_err1 = val_err1
            best_epoch = epoch
            print(Fore.GREEN + 'Best val_err1 {}'.format(best_err1) +
                  Fore.RESET)
            # test_loss, test_err1, test_err1 = validate(
            #     test_loader, model, criterion, epoch, True)
            # save test
        save_checkpoint({
            'args': args,
            'epoch': epoch,
            'best_epoch': best_epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best, args.save)
        if not is_best and epoch - best_epoch >= args.patience > 0:
            break
    print('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
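
The save_checkpoint helper itself is not shown in any of these examples. A minimal sketch consistent with how it is called here (a state dictionary, an is_best flag, a save directory, and an optional filename as in Example #1) might look like the following; the filenames are assumptions, not the original implementation.

import os
import shutil
import torch

def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    # Sketch only: write the latest checkpoint and, if it is the best so far,
    # keep a separate copy. The file names here are assumed defaults.
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_dir, 'model_best.pth.tar'))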
Example #3
def main():
    # parse arg and start experiment
    global args
    best_err1 = 100.
    best_epoch = 0

    args = parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            old_args = checkpoint['args']
            print('Old args:')
            print(old_args)
            # set args based on checkpoint
            if args.start_epoch <= 0:
                args.start_epoch = checkpoint['epoch'] + 1
            best_epoch = args.start_epoch - 1
            best_err1 = checkpoint['best_err1']
            for name in arch_resume_names:
                if name in vars(args) and name in vars(old_args):
                    setattr(args, name, getattr(old_args, name))
            model = getModel(**vars(args))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer
    optimizer = get_optimizer(model, args)

    trainer = Trainer(model, criterion, optimizer, args)


    # create dataloader
    if args.evaluate == 'val':
        _, val_loader, _ = getDataloaders(
            splits=('val',), **vars(args))
        trainer.test(val_loader, best_epoch)
        return
    elif args.evaluate == 'test':
        _, _, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        trainer.test(test_loader, best_epoch)
        return
    else:
        train_loader, val_loader, _ = getDataloaders(
            splits=('train', 'val'), **vars(args))

    # check if the folder exists
    if os.path.exists(args.save):
        print(Fore.RED + args.save + Fore.RESET
              + ' already exists!', file=sys.stderr)
        if not args.force:
            ans = input('Do you want to overwrite it? [y/N]:')
            if ans not in ('y', 'Y', 'yes', 'Yes'):
                sys.exit(1)
        tmp_path = '/tmp/{}_{}'.format(os.path.basename(args.save),
                                       time.time())
        print('move existing {} to {}'.format(args.save, Fore.RED
                                              + tmp_path + Fore.RESET))
        shutil.copytree(args.save, tmp_path)
        shutil.rmtree(args.save)
    os.makedirs(args.save)
    print('create folder: ' + Fore.GREEN + args.save + Fore.RESET)

    # copy code to save folder
    if args.save.find('debug') < 0:
        shutil.copytree(
            '.',
            os.path.join(
                args.save,
                'src'),
            symlinks=True,
            ignore=shutil.ignore_patterns(
                '*.pyc',
                '__pycache__',
                '*.path.tar',
                '*.pth',
                '*.ipynb',
                '.*',
                'data',
                'save',
                'save_backup'))

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*args):
        print(*args)
        print(*args, file=f_log)
    log_print('args:')
    log_print(args)
    print('model:', file=f_log)
    print(model, file=f_log)
    log_print('# of params:',
              str(sum([p.numel() for p in model.parameters()])))
    f_log.flush()
    torch.save(args, os.path.join(args.save, 'args.pth'))
    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_err1'
              '\tval_err1\ttrain_err5\tval_err5']
    if args.tensorboard:
        configure(args.save, flush_secs=5)

    for epoch in range(args.start_epoch, args.epochs + 1):

        # train for one epoch
        train_loss, train_err1, train_err5, lr = trainer.train(
            train_loader, epoch)

        if args.tensorboard:
            log_value('lr', lr, epoch)
            log_value('train_loss', train_loss, epoch)
            log_value('train_err1', train_err1, epoch)
            log_value('train_err5', train_err5, epoch)

        # evaluate on validation set
        val_loss, val_err1, val_err5 = trainer.test(val_loader, epoch)

        if args.tensorboard:
            log_value('val_loss', val_loss, epoch)
            log_value('val_err1', val_err1, epoch)
            log_value('val_err5', val_err5, epoch)

        # save scores to a tsv file, rewrite the whole file to prevent
        # accidental deletion
        scores.append(('{}\t{}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_err1, val_err1, train_err5, val_err5))
        with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
            print('\n'.join(scores), file=f)

        # remember best err@1 and save checkpoint
        is_best = val_err1 < best_err1
        if is_best:
            best_err1 = val_err1
            best_epoch = epoch
            print(Fore.GREEN + 'Best val_err1 {}'.format(best_err1) +
                  Fore.RESET)
            # test_loss, test_err1, test_err1 = validate(
            #     test_loader, model, criterion, epoch, True)
            # save test
        save_checkpoint({
            'args': args,
            'epoch': epoch,
            'best_epoch': best_epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best, args.save)
        if not is_best and epoch - best_epoch >= args.patience > 0:
            break
    print('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
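
Examples #1, #2, and #6 call create_save_folder(args.save, args.force), while Example #3 inlines equivalent folder handling. A sketch of such a helper, modeled on that inline logic (prompt before overwriting, back the existing folder up under /tmp, then recreate it), is given below; treat it as an approximation rather than the original helper.

import os
import shutil
import sys
import time

def create_save_folder(save_path, force=False):
    # Sketch based on the inline logic in Example #3.
    if os.path.exists(save_path):
        print(save_path + ' already exists!', file=sys.stderr)
        if not force:
            ans = input('Do you want to overwrite it? [y/N]: ')
            if ans not in ('y', 'Y', 'yes', 'Yes'):
                sys.exit(1)
        tmp_path = '/tmp/{}_{}'.format(os.path.basename(save_path), time.time())
        print('move existing {} to {}'.format(save_path, tmp_path))
        shutil.copytree(save_path, tmp_path)
        shutil.rmtree(save_path)
    os.makedirs(save_path)
    print('create folder: ' + save_path)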
Example #4
def main():
    # parse arg and start experiment
    global args
    best_ap = -1.
    best_iter = 0

    args = parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    # args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            old_args = checkpoint['args']
            print('Old args:')
            print(old_args)
            # set args based on checkpoint
            if args.start_iter <= 0:
                args.start_iter = checkpoint['iter'] + 1
            best_iter = args.start_iter - 1
            best_ap = checkpoint['best_ap']
            for name in arch_resume_names:
                if name in vars(args) and name in vars(old_args):
                    setattr(args, name, getattr(old_args, name))
            model = get_model(**vars(args))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (iter {})"
                  .format(args.resume, checkpoint['iter']))
        else:
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = get_model(**vars(args))

    # cudnn.benchmark = True
    cudnn.enabled = False

    # create dataloader
    if args.evaluate == 'val':
        train_loader, val_loader, test_loader = getDataloaders(
            splits=('val',), **vars(args))
        validate(val_loader, model, best_iter)
        return
    elif args.evaluate == 'test':
        train_loader, val_loader, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        validate(test_loader, model, best_iter)
        return
    else:
        train_loader, val_loader, test_loader = getDataloaders(
            splits=('train', 'val'), **vars(args))

    # define optimizer
    optimizer = get_optimizer(model, args)

    # check if the folder exists
    if os.path.exists(args.save):
        print(Fore.RED + args.save + Fore.RESET
              + ' already exists!', file=sys.stderr)
        if not args.force:
            ans = input('Do you want to overwrite it? [y/N]:')
            if ans not in ('y', 'Y', 'yes', 'Yes'):
                sys.exit(1)
        print('remove existing ' + args.save)
        shutil.rmtree(args.save)
    os.makedirs(args.save)
    print('create folder: ' + Fore.GREEN + args.save + Fore.RESET)

    # copy code to save folder
    if args.save.find('debug') < 0:
        shutil.copytree(
            '.',
            os.path.join(
                args.save,
                'src'),
            symlinks=True,
            ignore=shutil.ignore_patterns(
                '*.pyc',
                '__pycache__',
                '*.path.tar',
                '*.pth',
                '*.ipynb',
                '.*',
                'data',
                'save',
                'save_backup'))

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*args):
        print(*args)
        print(*args, file=f_log)
    log_print('args:')
    log_print(args)
    print('model:', file=f_log)
    print(model, file=f_log, flush=True)
    # log_print('model:')
    # log_print(model)
    # log_print('optimizer:')
    # log_print(vars(optimizer))
    log_print('# of params:',
              str(sum([p.numel() for p in model.parameters()])))
    torch.save(args, os.path.join(args.save, 'args.pth'))
    scores = ['iter\tlr\ttrain_loss\tval_ap']
    if args.tensorboard:
        configure(args.save, flush_secs=5)

    for i in range(args.start_iter, args.niters + 1, args.eval_freq):
        # print('iter {:3d} lr = {:.6e}'.format(i, lr))
        # if args.tensorboard:
        #     log_value('lr', lr, i)

        # train for args.eval_freq iterations
        train_loss = train(train_loader, model, optimizer,
                           i, args.eval_freq)
        i += args.eval_freq - 1

        # evaluate on validation set
        val_ap = validate(val_loader, model, i)

        # save scores to a tsv file, rewrite the whole file to prevent
        # accidental deletion
        lr = optimizer.param_groups[0]['lr']  # lr is not returned by train() here
        scores.append(('{}\t{}' + '\t{:.4f}' * 2)
                      .format(i, lr, train_loss, val_ap))
        with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
            print('\n'.join(scores), file=f)

        # remember best val_ap and save checkpoint
        # TODO: change this
        is_best = val_ap > best_ap
        if is_best:
            best_ap = val_ap
            best_iter = i
            print(Fore.GREEN + 'Best val_ap {}'.format(best_ap) +
                  Fore.RESET)
        save_checkpoint({
            'args': args,
            'iter': i,
            'best_iter': best_iter,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_ap': best_ap,
        }, is_best, args.save)
        if not is_best and i - best_iter >= args.patience > 0:
            break
    print('Best val_ap: {:.4f} at iter {}'.format(best_ap, best_iter))
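
Each example guards its TensorBoard logging with "if configure is None:", which implies an optional import along these lines at module level; the exact wording of the guard is an assumption.

# Optional dependency: fall back gracefully when tensorboard_logger is missing.
try:
    from tensorboard_logger import configure, log_value
except ImportError:
    configure = None
    log_value = None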
Example #5
    if test_task >= 0:
        assert accs[test_task] == accuracy
        return accs[test_task]

    if args.dset_name == 'C10' or args.dset_name == 'Fuzzy-C10':
        return np.mean(accs), accs
    else:
        return np.mean(accs[1:]), accs


from dataloader import getDataloaders

trainLoaders, _ = getDataloaders(dset_name=args.dset_name,
                                 shuffle=True,
                                 splits=['train'],
                                 data_root=args.data_dir,
                                 batch_size=args.batch_size,
                                 num_workers=0,
                                 num_tasks=num_tasks,
                                 raw=False)

testLoaders = None
if args.dset_name.find("Fuzzy") >= 0:
    _, testLoaders = getDataloaders(dset_name="C10",
                                    shuffle=True,
                                    splits=['test'],
                                    data_root=args.data_dir,
                                    batch_size=args.batch_size,
                                    num_workers=0,
                                    num_tasks=num_tasks,
                                    raw=False)
else:
Example #6
def main():
    # parse arg and start experiment
    global args
    best_err1 = 100.
    best_epoch = 0

    args = arg_parser.parse_args()
    args.config_of_data = config.datasets[args.data]
    args.num_classes = config.datasets[args.data]['num_classes']
    if configure is None:
        args.tensorboard = False
        print(Fore.RED +
              'WARNING: you don\'t have tensorboard_logger installed' +
              Fore.RESET)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            global checkpoint
            checkpoint = torch.load(args.resume)
            old_args = checkpoint['args']
            # set args based on checkpoint
            if args.start_epoch <= 0:
                args.start_epoch = checkpoint['epoch'] + 1
            best_epoch = args.start_epoch - 1
            best_err1 = checkpoint['best_err1']
            # arch_resume_names is, e.g., ['arch', 'depth', 'death_mode', 'death_rate',
            #                              'growth_rate', 'bn_size', 'compression']
            for name in arch_resume_names:
                if name in vars(args) and name in vars(old_args):
                    setattr(args, name, getattr(old_args, name))
            model = getModel(**vars(args))
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print(
                "=> no checkpoint found at '{}'".format(
                    Fore.RED +
                    args.resume +
                    Fore.RESET),
                file=sys.stderr)
            return
    else:
        # create model
        print("=> creating model '{}'".format(args.arch))
        model = getModel(**vars(args))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # define optimizer
    optimizer = get_optimizer(model, args)

    Trainer = import_module(args.trainer).Trainer
    trainer = Trainer(model, criterion, optimizer, args)

    # create dataloader
    if args.evaluate == 'train':
        train_loader, _, _ = getDataloaders(
            splits=('train',), **vars(args))
        trainer.test(train_loader, best_epoch)
        return
    elif args.evaluate == 'val':
        _, val_loader, _ = getDataloaders(
            splits=('val',), **vars(args))
        trainer.test(val_loader, best_epoch)
        return
    elif args.evaluate == 'test':
        _, _, test_loader = getDataloaders(
            splits=('test',), **vars(args))
        if args.test_death_mode == 'none':
            trainer.test(test_loader, best_epoch)
        else:
            print ("Stochastic depth testing...")
            nblocks = (args.depth - 2) // 2
            n = (args.depth - 2) // 6 
            section_reps=[n]*3

            if args.test_death_mode == 'stoch':
                all_top1 = []
                for n in range(nblocks):  # drop 0, 1, 2, ..., nblocks-1 blocks
                    print("Dropping " + str(n) + " blocks")
                    death_rates_list = [0] * (nblocks - n) + [1] * n
                    test_death_rate = []
                    count = 0
                    for i in range(len(section_reps)):
                        test_death_rate.append(death_rates_list[count:(count+section_reps[i])])
                        count += section_reps[i]
                    model = getModel(test_death_rate=test_death_rate, **vars(args))
                    model.load_state_dict(checkpoint['state_dict'])
                    optimizer = get_optimizer(model, args)
                    trainer = Trainer(model, criterion, optimizer, args)
                    _, top1, _ = trainer.test(test_loader, best_epoch)
                    all_top1.append(top1)

                with open(args.resume.split('/')[1] + '.csv', 'w') as f:
                    writer = csv.writer(f)
                    rows = zip(range(0, nblocks), all_top1)
                    for row in rows:
                        writer.writerow(row)
            else:
                for n in range(1, 25):
                    all_top1 = []
                    print("Dropping " + str(n) + " random blocks")
                    for t in range(10):  # randomly remove n blocks, 10 times
                        random_ind = random.sample(range(nblocks), n)
                        print(random_ind)
                        death_rates_list = [0]*nblocks
                        for ind in random_ind:
                            death_rates_list[ind] = 1
                        test_death_rate = []
                        count = 0
                        for i in range(len(section_reps)):
                            test_death_rate.append(death_rates_list[count:(count+section_reps[i])])
                            count += section_reps[i]
                        model = getModel(test_death_rate=test_death_rate, **vars(args))
                        model.load_state_dict(checkpoint['state_dict'])
                        optimizer = get_optimizer(model, args)
                        trainer = Trainer(model, criterion, optimizer, args)
                        _, top1, _ = trainer.test(test_loader, best_epoch)
                        all_top1.append(top1)
                    print(min(all_top1))

        return

    else:
        train_loader, val_loader, _ = getDataloaders(
            splits=('train', 'val'), **vars(args))

    # check if the folder exists
    create_save_folder(args.save, args.force)

    # set up logging
    global log_print, f_log
    f_log = open(os.path.join(args.save, 'log.txt'), 'w')

    def log_print(*args):
        print(*args)
        print(*args, file=f_log)
    log_print('args:')
    log_print(args)
    print('model:', file=f_log)
    print(model, file=f_log)
    log_print('# of params:',
              str(sum([p.numel() for p in model.parameters()])))
    f_log.flush()
    torch.save(args, os.path.join(args.save, 'args.pth'))
    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_err1'
              '\tval_err1\ttrain_err5\tval_err5']
    if args.tensorboard:
        configure(args.save, flush_secs=5)

    for epoch in range(args.start_epoch, args.epochs + 1):

        # train for one epoch
        train_loss, train_err1, train_err5, lr = trainer.train(
            train_loader, epoch)

        if args.tensorboard:
            log_value('lr', lr, epoch)
            log_value('train_loss', train_loss, epoch)
            log_value('train_err1', train_err1, epoch)
            log_value('train_err5', train_err5, epoch)

        # evaluate on validation set
        val_loss, val_err1, val_err5 = trainer.test(val_loader, epoch)

        if args.tensorboard:
            log_value('val_loss', val_loss, epoch)
            log_value('val_err1', val_err1, epoch)
            log_value('val_err5', val_err5, epoch)

        # save scores to a tsv file, rewrite the whole file to prevent
        # accidental deletion
        scores.append(('{}\t{}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_err1, val_err1, train_err5, val_err5))
        with open(os.path.join(args.save, 'scores.tsv'), 'w') as f:
            print('\n'.join(scores), file=f)

        # remember best err@1 and save checkpoint
        is_best = val_err1 < best_err1
        if is_best:
            best_err1 = val_err1
            best_epoch = epoch
            print(Fore.GREEN + 'Best val_err1 {}'.format(best_err1) +
                  Fore.RESET)
            # test_loss, test_err1, test_err1 = validate(
            #     test_loader, model, criterion, epoch, True)
            # save test
        save_checkpoint({
            'args': args,
            'epoch': epoch,
            'best_epoch': best_epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_err1': best_err1,
        }, is_best, args.save)
        if not is_best and epoch - best_epoch >= args.patience > 0:
            break
    print('Best val_err1: {:.4f} at epoch {}'.format(best_err1, best_epoch))
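
The stochastic-depth test in Example #6 twice partitions a flat list of per-block death rates into per-section lists whose lengths are given by section_reps. A small helper with the same logic, shown here only as an illustrative sketch, makes the intent easier to see:

def split_by_sections(death_rates_list, section_reps):
    # Mirrors the partitioning loop in Example #6, e.g.
    # split_by_sections([0, 0, 1, 1, 1, 1], [2, 2, 2]) -> [[0, 0], [1, 1], [1, 1]]
    test_death_rate = []
    count = 0
    for reps in section_reps:
        test_death_rate.append(death_rates_list[count:count + reps])
        count += reps
    return test_death_rate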