def main(args):
    """Train or evaluate the single-output affordance model (BCE loss).

    Depending on ``args`` this either plots an existing training log
    (``args.write``), runs validation once (``args.relabel`` together with
    ``args.evaluate``, or ``args.evaluate`` alone), or runs the full
    train/validate loop and checkpoints the best validation accuracy.

    Args:
        args: parsed command-line namespace. Several fields
            (``checkpoint``, ``train_batch``, ``test_batch``, ``epochs``,
            ``start_epoch``) may be overwritten in place, and ``vars(args)``
            is forwarded verbatim to the dataset constructor.

    Raises:
        ValueError: if ``args.dataset`` or ``args.solver`` is not supported.
    """
    import shutil  # local import: only needed for the code backup below

    global best_iou
    global idx
    global output_res
    global REDRAW
    global RELABEL
    global JUST_EVALUATE

    output_res = args.out_res

    # When resuming, log into the run directory made of the first two
    # components of the resume path ('checkpoint/exp/x.pth' -> 'checkpoint/exp').
    # NOTE(review): absolute or deeper resume paths give a different directory.
    if args.resume:
        args.checkpoint = '/'.join(args.resume.split('/')[:2])

    # Relabeling validates with batch size 1; otherwise --test overrides the
    # batch sizes and the epoch count.
    if args.relabel:
        args.test_batch = 1
    elif args.test:
        args.train_batch = 1
        args.test_batch = 1
        args.epochs = 20

    # Only plot the curves stored in the run's log, then stop.
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ('mpii', 'lsp'):
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = list(range(1, 18))
    elif args.dataset in ('sad', 'sad_step_2', 'sad_step_2_eval'):
        idx = [1]  # support affordance
    else:
        # `raise` rather than `assert False`: asserts vanish under `python -O`.
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Model and dataset classes are looked up by their CLI names.
    njoints = datasets.__dict__[args.dataset].njoints
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    criterion = losses.BCELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # Optionally resume weights, optimizer state, epoch and best score.
    title = args.dataset + ' ' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    log_names = ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc']
    logger = None
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # The dataset class is chosen by name; every CLI option is forwarded.
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Relabel mode sets the RELABEL flag; with --evaluate it validates once
    # and stops.
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            loss, acc = validate(val_loader, model, criterion, njoints,
                                 args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # Evaluation only: validate once and stop.
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc = validate(val_loader, model, criterion, njoints,
                             args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return

    # Snapshot the code of this run. Best-effort like the former `cp` shell
    # calls (a missing file is reported, not fatal), but without building a
    # shell command from user paths. The script itself is copied from its
    # absolute path so this also works when launched from another directory.
    backup_dir = os.path.join(args.checkpoint, 'code_backup')
    mkdir_p(backup_dir)
    for src in ('../affordance/models/affordance_classification.py',
                '../affordance/models/convlstm.py',
                '../affordance/datasets/sad_step_2_eval.py',
                os.path.abspath(__file__)):
        try:
            shutil.copy(src, os.path.join(backup_dir, os.path.basename(src)))
        except OSError as err:
            print("=> code backup skipped for '{}': {}".format(src, err))

    # A resume was requested but its file was missing: start a fresh log
    # instead of failing on an unbound `logger` after the first epoch.
    if logger is None:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    # Train and validate.
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # Optionally decay the datasets' `sigma` attribute every epoch.
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        valid_loss, valid_acc = validate(val_loader, model, criterion, njoints,
                                         args.checkpoint, args.debug,
                                         args.flip)
        print("Val acc: %.3f" % (valid_acc))

        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_acc])

        # Track the best validation accuracy (stored under 'best_iou').
        is_best_iou = valid_acc > best_iou
        best_iou = max(valid_acc, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best acc = %.3f" % (best_iou))
def main(args):
    """Train or evaluate the multi-task affordance model.

    Three criteria (IoU, BCE and focal loss) are passed as a list to
    ``train``/``validate``, which report attention, region and existence
    metrics. Depending on ``args`` this plots an existing log
    (``args.write``), validates once (``args.relabel`` with
    ``args.evaluate``, or ``args.evaluate`` alone), or runs the full
    train/validate loop, checkpointing the best
    ``region IoU + existence accuracy``.

    Args:
        args: parsed command-line namespace. Several fields (``lr``,
            ``checkpoint``, batch sizes, ``epochs``, ``start_epoch``) may be
            overwritten in place, and ``vars(args)`` is forwarded verbatim to
            the dataset constructor.

    Raises:
        ValueError: if ``args.solver`` is not supported.
    """
    import shutil  # local import: only needed for the code backup below

    global best_final_acc
    global idx
    global output_res
    global REDRAW
    global RELABEL
    global JUST_EVALUATE

    output_res = args.out_res

    # Fine-tuning from a pre-trained checkpoint uses a fixed, smaller LR.
    # NOTE(review): an older comment read "pre train lr = 5e-4" while the
    # value set is 5e-5 -- confirm which one is intended.
    if args.pre_train:
        args.lr = 5e-5

    # When resuming (not pre-training), log into the run directory made of
    # the first two components of the resume path.
    if args.resume and not args.pre_train:
        args.checkpoint = '/'.join(args.resume.split('/')[:2])

    # Batch-size / epoch overrides. These are independent `if`s, so --test
    # overrides the relabel batch size.
    if args.relabel:
        args.test_batch = 1
    if args.test:
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 10
    if args.evaluate and not args.relabel:
        args.test_batch = 4

    # Only plot the curves stored in the run's log, then stop.
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy (only 1 here).
    idx = [1]

    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Model and dataset classes are looked up by their CLI names.
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    # train()/validate() receive the criteria in this order: IoU, BCE, focal.
    criterions = [
        losses.IoULoss().to(device),
        losses.BCELoss().to(device),
        losses.FocalLoss().to(device),
    ]

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # Log columns, in the order logger.append() writes them below. Shared by
    # every branch (the pre-train branch used to list 'Val Attention Loss'
    # twice instead of 'Val Attention IoU').
    # NOTE(review): the 9th column receives the training *existence loss*
    # returned by train() although it is labelled 'Train Existence Acc'; the
    # label is kept in case log readers (e.g. draw_line_chart) rely on it.
    log_names = ['Epoch', 'LR', 'Train Attention Loss', 'Val Attention Loss',
                 'Val Attention IoU', 'Train Region Loss', 'Val Region Loss',
                 'Val Region IoU', 'Train Existence Acc', 'Val Existence Loss',
                 'Val Existence Acc', 'Val final acc']

    # Optionally load weights from a pre-trained or resumed checkpoint.
    title = args.dataset + ' ' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    logger = None
    if args.pre_train or args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # Only the weights are restored: training restarts at epoch 0
            # with a reset best score and a fresh optimizer state.
            args.start_epoch = 0
            best_final_acc = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if args.pre_train:
                # Pre-training writes a new log in args.checkpoint.
                logger = Logger(log_path, title=title)
                logger.set_names(log_names)
            else:
                logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # The dataset class is chosen by name; every CLI option is forwarded.
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Relabel mode sets the RELABEL flag; with --evaluate it validates once
    # and stops.
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
             val_existence_loss, val_existence_acc, val_final_acc) = validate(
                 val_loader, model, criterions, njoints, args.checkpoint,
                 args.debug, args.flip)
            print("Val final acc: %.3f" % (val_final_acc))
            return

    # Evaluation only: validate once and stop.
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
         val_existence_loss, val_existence_acc, val_final_acc) = validate(
             val_loader, model, criterions, njoints, args.checkpoint,
             args.debug, args.flip)
        print("Val final acc: %.3f" % (val_final_acc))
        return

    # Snapshot the code of this run. Best-effort like the former `cp` shell
    # calls (a missing file is reported, not fatal), but without building a
    # shell command from user paths. The script itself is copied from its
    # absolute path so this also works when launched from another directory.
    backup_dir = os.path.join(args.checkpoint, 'code_backup')
    mkdir_p(backup_dir)
    for src in ('../affordance/models/hourglass_final.py',
                '../affordance/datasets/sad_attention.py',
                os.path.abspath(__file__)):
        try:
            shutil.copy(src, os.path.join(backup_dir, os.path.basename(src)))
        except OSError as err:
            print("=> code backup skipped for '{}': {}".format(src, err))

    # A pre-train/resume file was requested but missing: start a fresh log
    # instead of failing on an unbound `logger` after the first epoch.
    if logger is None:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    # Train and validate.
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # Optionally decay the datasets' `sigma` attribute every epoch.
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        train_att_loss, train_region_loss, train_existence_loss = train(
            train_loader, model, criterions, optimizer, args.debug, args.flip)

        (val_att_loss, val_att_iou, val_region_loss, val_region_iou,
         val_existence_loss, val_existence_acc, val_final_acc) = validate(
             val_loader, model, criterions, njoints, args.checkpoint,
             args.debug, args.flip)
        print("Val region IoU: %.3f" % (val_region_iou))
        print("Val label acc: %.3f" % (val_existence_acc))
        # Model-selection score: region IoU + existence accuracy. It replaces
        # the final acc returned by validate().
        # NOTE(review): evaluation-only mode reports validate()'s own value
        # instead -- confirm the two definitions agree.
        val_final_acc = val_region_iou + val_existence_acc

        logger.append([epoch + 1, lr, train_att_loss, val_att_loss,
                       val_att_iou, train_region_loss, val_region_loss,
                       val_region_iou, train_existence_loss,
                       val_existence_loss, val_existence_acc, val_final_acc])

        is_best_acc = val_final_acc > best_final_acc
        best_final_acc = max(val_final_acc, best_final_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                # Same key the other training variants store their best
                # score under.
                'best_iou': best_final_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best_acc,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best val final acc = %.3f" % (best_final_acc))
# Example #3
# 0
def main(args):
    """Train or evaluate the affordance segmentation model (IoU loss).

    Depending on ``args`` this plots an existing log (``args.write``),
    validates once (``args.relabel`` with ``args.evaluate``, or
    ``args.evaluate`` alone), or runs the full train/validate loop,
    checkpointing the best validation IoU.

    Args:
        args: parsed command-line namespace. Several fields (``lr``,
            ``checkpoint``, batch sizes, ``epochs``, ``start_epoch``) may be
            overwritten in place, and ``vars(args)`` is forwarded verbatim to
            the dataset constructor.

    Raises:
        ValueError: if ``args.solver`` is not supported.
    """
    import shutil  # local import: only needed for the code backup below

    global best_iou
    global idx
    global output_res
    global REDRAW
    global RELABEL
    global JUST_EVALUATE

    output_res = args.out_res

    # Fine-tuning from a pre-trained checkpoint uses a fixed, smaller LR.
    # NOTE(review): an older comment read "pre train lr = 5e-4" while the
    # value set is 5e-5 -- confirm which one is intended.
    if args.pre_train:
        args.lr = 5e-5

    # When resuming (not pre-training), log into the run directory made of
    # the first two components of the resume path.
    if args.resume and not args.pre_train:
        args.checkpoint = '/'.join(args.resume.split('/')[:2])

    # Batch-size / epoch overrides. These are independent `if`s, so --test
    # overrides the relabel batch size.
    if args.relabel:
        args.test_batch = 1
    if args.test:
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 20
    if args.evaluate and not args.relabel:
        args.test_batch = 1

    # Only plot the curves stored in the run's log, then stop.
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy (only 1 here).
    idx = [1]

    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # Model and dataset classes are looked up by their CLI names.
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)
    model = torch.nn.DataParallel(model).to(device)

    criterion = losses.IoULoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError('Unknown solver: {}'.format(args.solver))

    # Optionally load weights from a pre-trained or resumed checkpoint.
    title = args.dataset + ' ' + args.arch
    log_path = join(args.checkpoint, 'log.txt')
    log_names = ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val IoU']
    logger = None
    if args.pre_train or args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # Only the weights are restored: training restarts at epoch 0
            # with a reset best score and a fresh optimizer state.
            args.start_epoch = 0
            best_iou = 0
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if args.pre_train:
                # Pre-training writes a new log in args.checkpoint.
                logger = Logger(log_path, title=title)
                logger.set_names(log_names)
            else:
                logger = Logger(log_path, title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # The dataset class is chosen by name; every CLI option is forwarded.
    train_dataset = datasets.__dict__[args.dataset](is_train=True,
                                                    **vars(args))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Relabel / evaluation-only outputs go to this fixed directory instead
    # of args.checkpoint.
    eval_dir = 'checkpoint_0701_bbox_hide'

    # Relabel mode sets the RELABEL flag; with --evaluate it validates once
    # and stops (the IoU is not printed here).
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            mkdir_p(eval_dir)
            loss, iou, predictions = validate(val_loader, model, criterion,
                                              njoints, eval_dir,
                                              args.debug, args.flip)
            return

    # Evaluation only: validate once and stop.
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        if args.debug:
            print('Draw pred /gt heatmap')
        JUST_EVALUATE = True
        mkdir_p(eval_dir)
        loss, iou, predictions = validate(val_loader, model, criterion,
                                          njoints, eval_dir, args.debug,
                                          args.flip)
        print("Val IoU: %.3f" % (iou))
        return

    # Snapshot the code of this run. Best-effort like the former `cp` shell
    # calls (a missing file is reported, not fatal), but without building a
    # shell command from user paths. The script itself is copied from its
    # absolute path so this also works when launched from another directory.
    backup_dir = os.path.join(args.checkpoint, 'code_backup')
    mkdir_p(backup_dir)
    for src in ('../affordance/models/hourglass.py',
                '../affordance/datasets/sad.py',
                os.path.abspath(__file__)):
        try:
            shutil.copy(src, os.path.join(backup_dir, os.path.basename(src)))
        except OSError as err:
            print("=> code backup skipped for '{}': {}".format(src, err))

    # A pre-train/resume file was requested but missing: start a fresh log
    # instead of failing on an unbound `logger` after the first epoch.
    if logger is None:
        logger = Logger(log_path, title=title)
        logger.set_names(log_names)

    # Train and validate.
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # Optionally decay the datasets' `sigma` attribute every epoch.
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        # `args.checkpoint` (was the undefined name `arg.checkpoint`, which
        # raised NameError on the first validation during training).
        valid_loss, valid_iou, predictions = validate(val_loader, model,
                                                      criterion, njoints,
                                                      args.checkpoint,
                                                      args.debug, args.flip)
        print("Val IoU: %.3f" % (valid_iou))

        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_iou])

        # Track the best validation IoU and checkpoint every epoch.
        is_best_iou = valid_iou > best_iou
        best_iou = max(valid_iou, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best iou = %.3f" % (best_iou))
def main(args):
    global best_iou
    global idx
    global output_res
    output_res = args.out_res

    # 2020.3.2
    global REDRAW

    # 2020.3.4
    # if you do type arg.resume
    # args.checkpoint would be derived from arg.resume
    if args.resume != '':
        args.checkpoint = ('/').join(args.resume.split('/')[:2])

    if args.relabel == True:
        args.test_batch = 1
    elif args.test == True:
        # args.train_batch = 4
        # args.test_batch = 4
        # args.epochs = 20
        args.train_batch = 2
        args.test_batch = 2
        args.epochs = 10

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    elif args.dataset == 'sad' or args.dataset == 'sad_step_2' or args.dataset == 'sad_step_2_eval':
        idx = [1]  # support affordance
    else:
        print("Unknown dataset: {}".format(args.dataset))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.BCELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](
        is_train=True,
        **vars(args))  #-> depend on args.dataset to replace with datasets
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label :
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            loss, acc = validate(val_loader, model, criterion, njoints,
                                 args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # evaluation only
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc = validate(val_loader, model, criterion, njoints,
                             args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return