def main(args):

    args.cuda = args.use_cuda and torch.cuda.is_available()

    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_atomic(
        args)

    #model = models.Atomic(args)
    model = models.Atomic_edge_only(args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    #{'single': 0, 'mutual': 1, 'avert': 2, 'refer': 3, 'follow': 4, 'share': 5}
    criterion = [
        torch.nn.CrossEntropyLoss(
            weight=torch.Tensor([0.05, 0.05, 0.25, 0.25, 0.25, 0.15])),
        torch.nn.MSELoss()
    ]

    # {'NA': 0, 'single': 1, 'mutual': 2, 'avert': 3, 'refer': 4, 'follow': 5, 'share': 6}

    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=args.lr_decay,
                                  patience=1,
                                  verbose=True,
                                  mode='max')
    # ------------------------
    # use multi-gpu

    if args.cuda and torch.cuda.device_count() > 1:
        print("Now Using ", len(args.device_ids), " GPUs!")

        model = torch.nn.DataParallel(model,
                                      device_ids=args.device_ids,
                                      output_device=args.device_ids[0]).cuda()
        #model=model.cuda()
        criterion[0] = criterion[0].cuda()
        criterion[1] = criterion[1].cuda()

    elif args.cuda:
        model = model.cuda()
        criterion[0] = criterion[0].cuda()
        criterion[1] = criterion[1].cuda()

    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args,
                                                       model,
                                                       optimizer,
                                                       path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_acc, avg_epoch_acc, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(
            args,
            model,
            optimizer,
            path=args.resume,
            version=args.model_load_version)

        if loaded_checkpoint:
            args, best_epoch_acc, avg_epoch_acc, model, optimizer = loaded_checkpoint

    # ------------------------------------------------------------------------------
    # Start Training!

    since = time.time()

    train_epoch_acc_all = []
    val_epoch_acc_all = []

    best_acc = 0
    avg_epoch_acc = 0

    for epoch in range(args.start_epoch, args.epochs):

        train_epoch_loss, train_epoch_acc = train(train_loader, model,
                                                  criterion, optimizer, epoch,
                                                  args)
        train_epoch_acc_all.append(train_epoch_acc)

        val_epoch_loss, val_epoch_acc = validate(validate_loader, model,
                                                 criterion, epoch, args)
        val_epoch_acc_all.append(val_epoch_acc)

        print('Epoch {}/{} Training Acc: {:.4f} Validation Acc: {:.4f}'.format(
            epoch, args.epochs - 1, train_epoch_acc, val_epoch_acc))
        print('*' * 15)

        scheduler.step(val_epoch_acc)

        is_best = val_epoch_acc > best_acc

        if is_best:
            best_acc = val_epoch_acc

        avg_epoch_acc = np.mean(val_epoch_acc_all)

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_epoch_acc': best_acc,
                'avg_epoch_acc': avg_epoch_acc,
                'optimizer': optimizer.state_dict(),
                'args': args
            },
            is_best=is_best,
            directory=args.resume,
            version='epoch_{}'.format(str(epoch)))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val Acc: {},  Final Avg Val Acc: {}'.format(
        best_acc, avg_epoch_acc))

    # ----------------------------------------------------------------------------------------------------------
    # test

    loaded_checkpoint = utils.load_best_checkpoint(args,
                                                   model,
                                                   optimizer,
                                                   path=args.resume)

    if loaded_checkpoint:
        args, best_epoch_acc, avg_epoch_acc, model, optimizer = loaded_checkpoint

    test_loader.dataset.round_cnt = {
        'single': 0,
        'mutual': 0,
        'avert': 0,
        'refer': 0,
        'follow': 0,
        'share': 0
    }
    test_loss, test_acc, confmat, top2_acc = test(test_loader, model,
                                                  criterion, args)

    # save test results
    if not isdir(args.save_test_res):
        os.mkdir(args.save_test_res)

    with open(os.path.join(args.save_test_res, 'raw_test_results.pkl'),
              'wb') as f:
        pickle.dump([test_loss, test_acc, confmat, top2_acc], f)

    print("Test Acc {}".format(test_acc))
    print("Top 2 Test Acc {}".format(top2_acc))

    # todo: need to change the mode here!
    get_metric_from_confmat(confmat, 'atomic')
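
# The training loops in this listing rely on utils.save_checkpoint / utils.load_best_checkpoint,
# whose implementation lives in the repository's utils module and is not shown here. The helper
# below is only a hedged sketch of the usual pattern (torch.save once per epoch plus a copy of the
# best snapshot); the file names 'checkpoint_<version>.pth' and 'model_best.pth' are illustrative
# assumptions, not the repository's actual naming scheme.
def _save_checkpoint_sketch(state, is_best, directory, version):
    import os
    import shutil
    import torch

    if not os.path.isdir(directory):
        os.makedirs(directory)

    checkpoint_file = os.path.join(directory, 'checkpoint_{}.pth'.format(version))
    torch.save(state, checkpoint_file)
    if is_best:
        # keep a separate copy of the best-performing snapshot
        shutil.copyfile(checkpoint_file, os.path.join(directory, 'model_best.pth'))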
def main(args):

    args.cuda = args.use_cuda and torch.cuda.is_available()

    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_attmat(
        args)

    model = models.AttMat(args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    criterion = torch.nn.CrossEntropyLoss()

    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=args.lr_decay,
                                  patience=1,
                                  verbose=True,
                                  mode='max')

    #------------------------
    # use multi-gpu

    if args.cuda and torch.cuda.device_count() > 1:
        print("Now Using ", len(args.device_ids), " GPUs!")

        model = torch.nn.DataParallel(model,
                                      device_ids=args.device_ids,
                                      output_device=args.device_ids[0]).cuda()
        criterion = criterion.cuda()

    elif args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args,
                                                       model,
                                                       optimizer,
                                                       path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(
            args,
            model,
            optimizer,
            path=args.resume,
            version=args.model_load_version)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint


    # ------------------------------------------------------------------------------
    # Start Training!

    since = time.time()

    train_epoch_acc_all = []
    val_epoch_acc_all = []

    best_acc = 0
    avg_epoch_acc = 0

    for epoch in range(args.start_epoch, args.epochs):

        train_epoch_loss, train_epoch_acc = train(train_loader, model,
                                                  criterion, optimizer, epoch,
                                                  args)
        train_epoch_acc_all.append(train_epoch_acc)

        val_epoch_loss, val_epoch_acc = validate(validate_loader, model,
                                                 criterion, epoch, args)
        val_epoch_acc_all.append(val_epoch_acc)

        print('Epoch {}/{} Training Acc: {:.4f} Validation Acc: {:.4f}'.format(
            epoch, args.epochs - 1, train_epoch_acc, val_epoch_acc))
        print('*' * 15)

        scheduler.step(val_epoch_acc)

        is_best = val_epoch_acc > best_acc

        if is_best:
            best_acc = val_epoch_acc

        avg_epoch_acc = np.mean(val_epoch_acc_all)

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_epoch_acc': best_acc,
                'avg_epoch_acc': avg_epoch_acc,
                'optimizer': optimizer.state_dict(),
                'args': args
            },
            is_best=is_best,
            directory=args.resume,
            version='epoch_{}'.format(str(epoch)))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val Acc: {},  Final Avg Val Acc: {}'.format(
        best_acc, avg_epoch_acc))

    #----------------------------------------------------------------------------------------------------------
    # test

    loaded_checkpoint = utils.load_best_checkpoint(args,
                                                   model,
                                                   optimizer,
                                                   path=args.resume)

    if loaded_checkpoint:
        args, best_epoch_acc, avg_epoch_acc, model, optimizer = loaded_checkpoint

    #
    test_loss, test_acc, one_acc, zero_acc = test(test_loader, model,
                                                  criterion, args)

    print("Test Acc {}, One Acc {}, Zero Acc {}".format(
        test_acc, one_acc, zero_acc))

    # save test results
    if not isdir(args.save_test_res):
        os.mkdir(args.save_test_res)

    with open(os.path.join(args.save_test_res, 'raw_test_results.pkl'),
              'wb') as f:
        pickle.dump([test_loss, test_acc, one_acc, zero_acc], f)
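
# test() above reports an overall accuracy plus one_acc / zero_acc, presumably the accuracy on
# the positive (1) and negative (0) entries of the attention matrix. The snippet below is only a
# hedged sketch of how such per-label accuracies could be computed from prediction and
# ground-truth arrays; the names and signature are illustrative, not taken from the repository.
def _per_label_accuracy_sketch(pred, gt):
    import numpy as np

    pred = np.asarray(pred).ravel()
    gt = np.asarray(gt).ravel()

    one_mask = gt == 1
    zero_mask = gt == 0
    # accuracy restricted to the samples whose ground truth is 1 (resp. 0)
    one_acc = (pred[one_mask] == 1).mean() if one_mask.any() else 0.0
    zero_acc = (pred[zero_mask] == 0).mean() if zero_mask.any() else 0.0
    return one_acc, zero_acc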
def main(args):
    args.cuda = args.use_cuda and torch.cuda.is_available()

    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_AttMat_msg_lstm(
        args)

    model_args = {'roi_feature_size': args.roi_feature_size, 'edge_feature_size': args.roi_feature_size,
                  'node_feature_size': args.roi_feature_size,
                  'message_size': args.message_size, 'link_hidden_size': args.link_hidden_size,
                  'link_hidden_layers': args.link_hidden_layers, 'propagate_layers': args.propagate_layers,
                  'big_attr_classes': args.big_attr_class_num, 'lstm_hidden_size': args.lstm_hidden_size}

    model = models.AttMat_msg_lstm(model_args, args)

    # TODO: check grads and then set the learning rate for Adam
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    criterion = torch.nn.CrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args, model, optimizer, path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(args, model, optimizer, path=args.resume,
                                                       version=args.model_load_version)
        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    train_error_history = list()
    train_loss_history = list()
    val_error_history = list()
    val_loss_history = list()

    best_epoch_error = np.inf

    for epoch in range(args.start_epoch, args.epochs):

        train_error_rate_cur_epoch, train_loss_cur_epoch = train(train_loader, model, criterion, optimizer, epoch, args)
        train_error_history.append(train_error_rate_cur_epoch)
        train_loss_history.append(train_loss_cur_epoch)

        val_error_rate_cur_epoch, val_loss_cur_epoch = validate(validate_loader, model, criterion, args)
        val_error_history.append(val_error_rate_cur_epoch)
        val_loss_history.append(val_loss_cur_epoch)

        # TODO: reconsider this schedule; there is no need to decrease the lr for Adam every epoch
        # (a built-in scheduler sketch follows this function)
        if epoch > 0 and epoch % 1 == 0:
            args.lr *= args.lr_decay

            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        is_best = val_error_rate_cur_epoch < best_epoch_error

        best_epoch_error = min(val_error_rate_cur_epoch, best_epoch_error)

        avg_epoch_error = np.mean(val_error_history)

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_epoch_error': best_epoch_error,
            'avg_epoch_error': avg_epoch_error,
            'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume,
            version='epoch_{}'.format(str(epoch)))

        print('best_epoch_error: {}, avg_epoch_error: {}'.format(best_epoch_error, avg_epoch_error))

    # test
    # loaded_checkpoint=utils.load_best_checkpoint(args,model,optimizer,path=args.resume)
    loaded_checkpoint = utils.load_last_checkpoint(args, model, optimizer, path=args.resume,
                                                   version=args.model_load_version)
    if loaded_checkpoint:
        args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    test(test_loader, model, args)
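
# The manual per-epoch decay in the training loop above (args.lr *= args.lr_decay written back
# into optimizer.param_groups) can also be expressed with a built-in scheduler. This is a minimal
# standalone sketch, not part of the original pipeline; the toy model, the 0.9 decay factor and
# the epoch count are illustrative assumptions.
def _lr_schedule_sketch():
    import torch

    toy_model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
    # multiply the lr by `gamma` after every epoch, mirroring args.lr *= args.lr_decay
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    for epoch in range(3):
        # ... run one training epoch here ...
        scheduler.step()
        print('epoch {}: lr = {}'.format(epoch, optimizer.param_groups[0]['lr']))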
def main(args):

    args.cuda = args.use_cuda and torch.cuda.is_available()

    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_atomic(
        args)

    #model = models.Atomic(args)
    model = models.Atomic_2branch(args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    #{'single': 0, 'mutual': 1, 'avert': 2, 'refer': 3, 'follow': 4, 'share': 5}
    criterion = [
        torch.nn.CrossEntropyLoss(
            weight=torch.Tensor([0.05, 0.05, 0.25, 0.25, 0.25, 0.15])),
        torch.nn.MSELoss()
    ]

    # {'NA': 0, 'single': 1, 'mutual': 2, 'avert': 3, 'refer': 4, 'follow': 5, 'share': 6}

    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=args.lr_decay,
                                  patience=1,
                                  verbose=True,
                                  mode='max')
    # ------------------------
    # use multi-gpu

    if args.cuda and torch.cuda.device_count() > 1:
        print("Now Using ", len(args.device_ids), " GPUs!")

        model = torch.nn.DataParallel(model,
                                      device_ids=args.device_ids,
                                      output_device=args.device_ids[0]).cuda()
        #model=model.cuda()
        criterion[0] = criterion[0].cuda()
        criterion[1] = criterion[1].cuda()

    elif args.cuda:
        model = model.cuda()
        criterion[0] = criterion[0].cuda()
        criterion[1] = criterion[1].cuda()

    # ----------------------------------------------------------------------------------------------------------
    # test

    loaded_checkpoint = utils.load_best_checkpoint(args,
                                                   model,
                                                   optimizer,
                                                   path=args.resume)

    if loaded_checkpoint:
        args, best_epoch_acc, avg_epoch_acc, model, optimizer = loaded_checkpoint

    test_loader.dataset.round_cnt = {
        'single': 0,
        'mutual': 0,
        'avert': 0,
        'refer': 0,
        'follow': 0,
        'share': 0
    }
    test_loss, test_acc, confmat, top2_acc, correct_rec, error_rec = test(
        test_loader, model, criterion, args)

    #fw = open(os.path.join(args.tmp_root, 'correct_list.txt'), 'w')
    fc_single = open(os.path.join(args.tmp_root, 'correct_single.txt'), 'w')
    fc_mutual = open(os.path.join(args.tmp_root, 'correct_mutual.txt'), 'w')
    fc_avert = open(os.path.join(args.tmp_root, 'correct_avert.txt'), 'w')
    fc_refer = open(os.path.join(args.tmp_root, 'correct_refer.txt'), 'w')
    fc_follow = open(os.path.join(args.tmp_root, 'correct_follow.txt'), 'w')
    fc_share = open(os.path.join(args.tmp_root, 'correct_share.txt'), 'w')

    fe_single = open(os.path.join(args.tmp_root, 'error_single.txt'), 'w')
    fe_mutual = open(os.path.join(args.tmp_root, 'error_mutual.txt'), 'w')
    fe_avert = open(os.path.join(args.tmp_root, 'error_avert.txt'), 'w')
    fe_refer = open(os.path.join(args.tmp_root, 'error_refer.txt'), 'w')
    fe_follow = open(os.path.join(args.tmp_root, 'error_follow.txt'), 'w')
    fe_share = open(os.path.join(args.tmp_root, 'error_share.txt'), 'w')

    # map the label keys used in correct_rec / error_rec to the per-class output files
    correct_files = {'0': fc_single, '1': fc_mutual, '2': fc_avert,
                     '3': fc_refer, '4': fc_follow, '5': fc_share}
    error_files = {'0': fe_single, '1': fe_mutual, '2': fe_avert,
                   '3': fe_refer, '4': fe_follow, '5': fe_share}

    def write_record(out_file, item):
        # write the five recorded values, then the ground-truth label and the prediction
        for i in range(5):
            out_file.write(str(item[0][i].cpu().numpy()))
            out_file.write(' ')
        out_file.write(str(item[1].cpu().numpy()))
        out_file.write(' ')
        out_file.write(str(item[2].cpu().numpy()))
        out_file.write('\n')

    for item in correct_rec:
        for label, out_file in correct_files.items():
            if item[1] == label:
                write_record(out_file, item)
                break

    for item in error_rec:
        for label, out_file in error_files.items():
            if item[1] == label:
                write_record(out_file, item)
                break

    for out_file in list(correct_files.values()) + list(error_files.values()):
        out_file.close()
def main(args):


    args.cuda = args.use_cuda and torch.cuda.is_available()

    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_resnet_msgpassing_balanced_lstm(args)

    model_args = {'roi_feature_size': args.roi_feature_size, 'edge_feature_size': args.roi_feature_size,
                  'node_feature_size': args.roi_feature_size,
                  'message_size': args.message_size, 'link_hidden_size': args.link_hidden_size,
                  'link_hidden_layers': args.link_hidden_layers, 'propagate_layers': args.propagate_layers,
                  'big_attr_classes': args.big_attr_class_num, 'lstm_hidden_size': args.lstm_hidden_size}

    model = models.HGNN_resnet_msgpassing_balanced_lstm(model_args)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    criterion = torch.nn.CrossEntropyLoss()

    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args, model, optimizer, path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(args, model, optimizer, path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    train_error_history = list()
    train_loss_history = list()
    val_error_history = list()
    val_loss_history = list()

    best_epoch_error = np.inf


    for epoch in range(args.start_epoch, args.epochs):

        train_error_rate_cur_epoch, train_loss_cur_epoch = train(train_loader, model, criterion, optimizer, epoch, args)
        train_error_history.append(train_error_rate_cur_epoch)
        train_loss_history.append(train_loss_cur_epoch)

        val_error_rate_cur_epoch, val_loss_cur_epoch = validate(validate_loader, model, criterion, args)
        val_error_history.append(val_error_rate_cur_epoch)
        val_loss_history.append(val_loss_cur_epoch)

        if epoch > 0 and epoch % 1 == 0:
            args.lr *= args.lr_decay

            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

        is_best = val_error_rate_cur_epoch < best_epoch_error

        best_epoch_error = min(val_error_rate_cur_epoch, best_epoch_error)

        avg_epoch_error = np.mean(val_error_history)

        utils.save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_epoch_error': best_epoch_error,
            'avg_epoch_error': avg_epoch_error,
            'optimizer': optimizer.state_dict(), }, is_best=is_best, directory=args.resume)

        print('best_epoch_error: {}, avg_epoch_error: {}'.format(best_epoch_error, avg_epoch_error))


    # test
    # loaded_checkpoint = utils.load_best_checkpoint(args, model, optimizer, path=args.resume)
    loaded_checkpoint = utils.load_last_checkpoint(args, model, optimizer, path=args.resume)
    if loaded_checkpoint:
        args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    test(test_loader, model, args)
def main(args):

    args.cuda = args.use_cuda and torch.cuda.is_available()
    train_set, validate_set, test_set, train_loader, validate_loader, test_loader = get_data.get_data_resnet_fc(
        args)
    model = models.HGNN_resnet_fc()
    # TODO: try to use the step policy for Adam, also consider the step interval
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # TODO: related to line 99 (e.g. max for auc, min for loss; try to check the definition of this method)
    scheduler = ReduceLROnPlateau(optimizer,
                                  factor=args.lr_decay,
                                  patience=1,
                                  verbose=True,
                                  mode='min')
    # TODO: double check this loss, also the output of the network
    criterion = torch.nn.CrossEntropyLoss()
    # criterion=torch.nn.MSELoss()

    #------------------------
    # use multi-gpu

    if args.cuda and torch.cuda.device_count() > 1:
        print("Now Using ", len(args.device_ids), " GPUs!")

        #model=model.to(device_ids[0])
        model = torch.nn.DataParallel(model,
                                      device_ids=args.device_ids,
                                      output_device=args.device_ids[0]).cuda()
        criterion = criterion.cuda()

    elif args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    if args.load_best_checkpoint:
        loaded_checkpoint = utils.load_best_checkpoint(args,
                                                       model,
                                                       optimizer,
                                                       path=args.resume)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    if args.load_last_checkpoint:
        loaded_checkpoint = utils.load_last_checkpoint(
            args,
            model,
            optimizer,
            path=args.resume,
            version=args.model_load_version)

        if loaded_checkpoint:
            args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

        # for param_group in optimizer.param_groups:
        #     param_group['lr'] = args.lr
    #------------------------------------------------------------------------------
    # Train

    since = time.time()

    train_epoch_loss_all = []
    val_epoch_loss_all = []

    best_loss = np.inf
    avg_epoch_loss = np.inf

    for epoch in range(args.start_epoch, args.epochs):

        train_epoch_loss = train(train_loader, model, criterion, optimizer,
                                 epoch, args)
        train_epoch_loss_all.append(train_epoch_loss)
        #visdom_viz(vis, train_epoch_loss_all, win=0, ylabel='Training Epoch Loss', title=args.project_name, color='green')

        val_epoch_loss = validate(validate_loader, model, criterion, epoch,
                                  args)
        val_epoch_loss_all.append(val_epoch_loss)
        #visdom_viz(vis, val_epoch_loss_all, win=1, ylabel='Validation Epoch Loss', title=args.project_name,color='blue')

        print(
            'Epoch {}/{} Training Loss: {:.4f} Validation Loss: {:.4f}'.format(
                epoch, args.epochs - 1, train_epoch_loss, val_epoch_loss))
        print('*' * 15)

        # TODO: reduce the lr when there is no gain on the validation metric (e.g. auc, loss)
        scheduler.step(val_epoch_loss)

        is_best = val_epoch_loss < best_loss

        if is_best:
            best_loss = val_epoch_loss

        avg_epoch_loss = np.mean(val_epoch_loss_all)

        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_epoch_error': best_loss,
                'avg_epoch_error': avg_epoch_loss,
                'optimizer': optimizer.state_dict()
            },
            is_best=is_best,
            directory=args.resume,
            version='epoch_{}'.format(str(epoch)))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Val Loss: {},  Final Avg Val Loss: {}'.format(
        best_loss, avg_epoch_loss))

    #-------------------------------------------------------------------------------------------------------------
    # test
    loaded_checkpoint = utils.load_best_checkpoint(args,
                                                   model,
                                                   optimizer,
                                                   path=args.resume)
    if loaded_checkpoint:
        args, best_epoch_error, avg_epoch_error, model, optimizer = loaded_checkpoint

    pred_label, gt_label, test_loss = test(test_loader, model, criterion, args)

    print("Test Epoch Loss {}".format(test_loss))

    # save test results
    if not isdir(args.save_test_res):
        os.mkdir(args.save_test_res)

    with open(os.path.join(args.save_test_res, 'raw_test_results.pkl'),
              'wb') as f:
        pickle.dump([pred_label, gt_label, test_loss], f)

    #todo: check get_test_metric
    recall, precision, F_one_score, acc, avg_acc, ConfMat = get_test_metric(
        pred_label, gt_label, args)

    print(
        '[====Test results Small Attr====] \n recall: {} \n precision: {} \n F1 score: {} \n acc: {} \n avg acc: {} \n Confusion Matrix: \n {}'
        .format(recall, precision, F_one_score, acc, avg_acc, ConfMat))
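
# get_test_metric is defined elsewhere in this repository; the function below is only a hedged
# sketch of how comparable per-class metrics could be derived from the predicted and ground-truth
# label arrays with scikit-learn (an assumed extra dependency, not used by the original code).
def _test_metric_sketch(pred_label, gt_label):
    import numpy as np
    from sklearn.metrics import (accuracy_score, confusion_matrix,
                                 precision_recall_fscore_support)

    pred_label = np.asarray(pred_label).ravel()
    gt_label = np.asarray(gt_label).ravel()

    # per-class precision / recall / F1 (average=None keeps one value per class)
    precision, recall, f_one_score, _ = precision_recall_fscore_support(
        gt_label, pred_label, average=None)
    acc = accuracy_score(gt_label, pred_label)
    conf_mat = confusion_matrix(gt_label, pred_label)
    # per-class accuracy averaged over classes (balanced accuracy)
    avg_acc = np.mean(np.diag(conf_mat) / conf_mat.sum(axis=1).clip(min=1))

    return recall, precision, f_one_score, acc, avg_acc, conf_mat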