def train(model, criterion_softmax, criterion_binary, train_set, val_set, opt):
    """Train ``model`` on batches carrying both softmax and binary targets.

    Uses SGD over the parameter groups returned by ``modify_last_layer_lr``
    and a ``StepLR`` schedule that is stepped once per epoch. A checkpoint
    labelled ``epoch + 1`` is saved whenever the 0-based ``epoch`` index is a
    multiple of ``opt.save_epoch_freq`` (so always after the first epoch).

    Args:
        model: network to train; its ``named_parameters()`` are handed to
            ``modify_last_layer_lr``.
        criterion_softmax: loss passed through to ``forward_batch`` together
            with ``target_softmax``.
        criterion_binary: loss passed through to ``forward_batch`` together
            with ``target_binary``.
        train_set: iterable yielding ``(inputs, target_softmax,
            target_binary)`` batches.
        val_set: not used by this function; kept for interface compatibility.
        opt: options object; reads ``lr``, ``lr_mult_w``, ``lr_mult_b``,
            ``momentum``, ``weight_decay``, ``lr_decay_in_epoch``, ``gamma``,
            ``sum_epoch`` and ``save_epoch_freq``.
    """
    # Per the original comment, modify_last_layer_lr adjusts the learning rate
    # of the last layer using the lr_mult_w / lr_mult_b multipliers.
    finetune_params = modify_last_layer_lr(model.named_parameters(),
                                           opt.lr, opt.lr_mult_w, opt.lr_mult_b)
    optimizer = optim.SGD(finetune_params,
                          opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    # Multiply the learning rate by `gamma` every `lr_decay_in_epoch` epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)

    total_batch_iter = 0
    logging.info("####################Train Model###################")
    for epoch in range(opt.sum_epoch):
        logging.info('Begin of epoch %d', epoch)
        for inputs, target_softmax, target_binary in train_set:
            # Only the loss is needed here; output and per-term losses are
            # discarded.
            _, loss, _ = forward_batch(model, criterion_softmax,
                                       criterion_binary, inputs,
                                       target_softmax, target_binary,
                                       opt, "Train")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_batch_iter += 1

        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d',
                         epoch + 1, total_batch_iter)
            save_model(model, opt, epoch + 1)

        # Advance the LR schedule once per epoch, after the optimizer steps.
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        logging.info('learning rate = %.7f epoch = %d', lr, epoch)
    logging.info("--------Optimization Done--------")
def train(model, criterion, train_set, val_set, opt, labels=None):
    """Train ``model`` with Adam, optionally gating updates on batch accuracy.

    The behavior depends on ``opt.mode``:

    * ``'Train'``: every batch is forwarded in ``"Train"`` mode and
      back-propagated.
    * ``'Test-Train'``: every batch is forwarded in ``"Test-Train"`` mode and
      the optimizer steps only when ``batch_accuracy[1] >= THRES``. ``THRES``
      is a name defined outside this function. The original comment says this
      mode should be used with batch size 1.

    At the frequencies configured in ``opt`` the function also logs and plots
    training loss and accuracy, plots a sample of the input images, runs
    validation, and saves model snapshots.

    Args:
        model: network to train.
        criterion: loss passed through to ``forward_batch`` and ``validate``.
        train_set: iterable of ``(inputs, targets)`` batches. It must support
            ``len()``, which is used for the fractional-epoch plot x-axis.
        val_set: data passed to ``validate``. Validation is skipped when
            ``len(val_set) == 0``.
        opt: options object.
        labels: optional ``(rid2name, id2rid)`` pair used to build readable
            captions for the plotted images.

    Raises:
        ValueError: if ``opt.mode`` is neither ``'Train'`` nor
            ``'Test-Train'``. Previously such a mode silently skipped training
            and later failed with ``NameError``.
    """
    if opt.mode not in ('Train', 'Test-Train'):
        raise ValueError("unsupported opt.mode %r; expected 'Train' or "
                         "'Test-Train'" % (opt.mode,))

    webvis = WebVisualizer(opt)

    # Per the original comment, modify_last_layer_lr adjusts the learning rate
    # of the last layer using the lr_mult_w / lr_mult_b multipliers.
    finetune_params = modify_last_layer_lr(model.named_parameters(), opt.lr,
                                           opt.lr_mult_w, opt.lr_mult_b)
    # NOTE(review): the default learning rate is hard-coded to 1e-6 instead of
    # opt.lr. Parameter groups carrying their own 'lr' override it (whether
    # modify_last_layer_lr sets one needs confirming). This variant uses no
    # LR scheduler.
    optimizer = optim.Adam(finetune_params, 0.000001)

    if labels is not None:
        rid2name, id2rid = labels

    train_batch_num = len(train_set)
    total_batch_iter = 0
    logging.info("####################Train Model###################")

    for epoch in range(opt.sum_epoch):
        epoch_start_t = time.time()
        epoch_batch_iter = 0
        logging.info('Begin of epoch %d', epoch)
        for inputs, targets in train_set:
            if opt.mode == 'Train':
                output, loss, loss_list = forward_batch(
                    model, criterion, inputs, targets, opt, "Train")
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            else:  # 'Test-Train' (validated above); use batch size 1
                output, loss, loss_list = forward_batch(
                    model, criterion, inputs, targets, opt, "Test-Train")
                batch_accuracy = calc_accuracy(output, targets,
                                               opt.score_thres, opt, opt.top_k)
                # Step only when batch_accuracy[1] reaches THRES. THRES comes
                # from outside this function; confirm its meaning there.
                if batch_accuracy[1] >= THRES:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            webvis.reset()
            epoch_batch_iter += 1
            total_batch_iter += 1

            # Log (and optionally plot) training loss and accuracy.
            if total_batch_iter % opt.display_train_freq == 0:
                batch_accuracy = calc_accuracy(output, targets,
                                               opt.score_thres, opt, opt.top_k)
                util.print_loss(loss_list, "Train", epoch, total_batch_iter)
                util.print_accuracy(batch_accuracy, "Train", epoch,
                                    total_batch_iter)
                if opt.display_id > 0:
                    x_axis = epoch + float(epoch_batch_iter) / train_batch_num
                    # TODO support accuracy visualization of multiple top_k
                    accuracy_list = [
                        batch_accuracy[c][opt.top_k[0]]["ratio"]
                        for c in range(len(batch_accuracy))
                    ]
                    webvis.plot_points(x_axis, loss_list, "Loss", "Train")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy",
                                       "Train")

            # Plot a sample of the current batch's input images.
            if total_batch_iter % opt.display_data_freq == 0:
                batch_size = inputs.size()[0]
                # Clamp so display_image_ratio > 1 cannot index past the batch.
                show_image_num = min(
                    batch_size,
                    int(np.ceil(opt.display_image_ratio * batch_size)))
                image_list = []
                for index in range(show_image_num):
                    input_im = util.tensor2im(inputs[index], opt.mean, opt.std)
                    class_label = "Image_" + str(index)
                    if labels is not None:
                        target_ids = [
                            targets[c][index] for c in range(opt.class_num)
                        ]
                        rids = [id2rid[j][k] for j, k in enumerate(target_ids)]
                        class_label += "_"
                        class_label += "#".join(
                            [rid2name[j][k] for j, k in enumerate(rids)])
                    image_list.append((class_label, input_im))
                image_dict = OrderedDict(image_list)
                # True only on multiples of update_html_freq. The previous bare
                # modulo was truthy on every *other* iteration, which inverted
                # the intent.
                save_result = total_batch_iter % opt.update_html_freq == 0
                webvis.plot_images(image_dict,
                                   opt.display_id + 2 * opt.class_num, epoch,
                                   save_result)

            # Validate and log (and optionally plot) validation metrics.
            if (len(val_set) > 0
                    and total_batch_iter % opt.display_validate_freq == 0):
                val_accuracy, val_loss = validate(model, criterion, val_set,
                                                  opt)
                x_axis = epoch + float(epoch_batch_iter) / train_batch_num
                accuracy_list = [
                    val_accuracy[c][opt.top_k[0]]["ratio"]
                    for c in range(len(val_accuracy))
                ]
                util.print_loss(val_loss, "Validate", epoch, total_batch_iter)
                util.print_accuracy(val_accuracy, "Validate", epoch,
                                    total_batch_iter)
                if opt.display_id > 0:
                    webvis.plot_points(x_axis, val_loss, "Loss", "Validate")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy",
                                       "Validate")

            # Periodic snapshot, labelled with the current 0-based epoch.
            if total_batch_iter % opt.save_batch_iter_freq == 0:
                logging.info(
                    "saving the latest model (epoch %d, total_batch_iter %d)",
                    epoch, total_batch_iter)
                save_model(model, opt, epoch)
                # TODO snapshot loss and accuracy

        logging.info('End of epoch %d / %d \t Time Taken: %d sec',
                     epoch, opt.sum_epoch, time.time() - epoch_start_t)

        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d',
                         epoch + 1, total_batch_iter)
            save_model(model, opt, epoch + 1)
    logging.info("--------Optimization Done--------")
def train(model, criterion, train_set, val_set, opt, labels=None):
    """Train ``model`` with SGD and a step learning-rate schedule.

    Each batch from ``train_set`` is forwarded through ``forward_batch`` in
    ``"Train"`` mode and back-propagated. At the frequencies configured in
    ``opt`` the function also logs and plots training loss and accuracy,
    plots a sample of the input images, runs validation, and saves model
    snapshots. The learning rate is advanced once per epoch by ``StepLR``.

    Args:
        model: network to train.
        criterion: loss passed through to ``forward_batch`` and ``validate``.
        train_set: iterable of ``(inputs, targets)`` batches. It must support
            ``len()``, which is used for the fractional-epoch plot x-axis.
        val_set: data passed to ``validate``. Validation is skipped when
            ``len(val_set) == 0``.
        opt: options object (learning-rate, schedule, display and save
            settings).
        labels: optional ``(rid2name, id2rid)`` pair used to build readable
            captions for the plotted images.
    """
    webvis = WebVisualizer(opt)

    # Per the original comment, modify_last_layer_lr adjusts the learning rate
    # of the last layer using the lr_mult_w / lr_mult_b multipliers.
    finetune_params = modify_last_layer_lr(model.named_parameters(),
                                           opt.lr, opt.lr_mult_w, opt.lr_mult_b)
    optimizer = optim.SGD(finetune_params,
                          opt.lr,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    # Multiply the learning rate by `gamma` every `lr_decay_in_epoch` epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=opt.lr_decay_in_epoch,
                                          gamma=opt.gamma)
    if labels is not None:
        rid2name, id2rid = labels

    train_batch_num = len(train_set)
    total_batch_iter = 0
    logging.info("####################Train Model###################")
    for epoch in range(opt.sum_epoch):
        epoch_start_t = time.time()
        epoch_batch_iter = 0
        logging.info('Begin of epoch %d', epoch)
        for inputs, targets in train_set:
            output, loss, loss_list = forward_batch(model, criterion, inputs,
                                                    targets, opt, "Train")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            webvis.reset()
            epoch_batch_iter += 1
            total_batch_iter += 1

            # Log (and optionally plot) training loss and accuracy.
            if total_batch_iter % opt.display_train_freq == 0:
                batch_accuracy = calc_accuracy(output, targets,
                                               opt.score_thres, opt.top_k)
                util.print_loss(loss_list, "Train", epoch, total_batch_iter)
                util.print_accuracy(batch_accuracy, "Train", epoch,
                                    total_batch_iter)
                if opt.display_id > 0:
                    x_axis = epoch + float(epoch_batch_iter) / train_batch_num
                    # TODO support accuracy visualization of multiple top_k
                    accuracy_list = [
                        batch_accuracy[c][opt.top_k[0]]["ratio"]
                        for c in range(len(batch_accuracy))
                    ]
                    webvis.plot_points(x_axis, loss_list, "Loss", "Train")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy",
                                       "Train")

            # Plot a sample of the current batch's input images.
            if total_batch_iter % opt.display_data_freq == 0:
                batch_size = inputs.size()[0]
                # Clamp so display_image_ratio > 1 cannot index past the batch.
                show_image_num = min(
                    batch_size,
                    int(np.ceil(opt.display_image_ratio * batch_size)))
                image_list = []
                for index in range(show_image_num):
                    input_im = util.tensor2im(inputs[index], opt.mean, opt.std)
                    class_label = "Image_" + str(index)
                    if labels is not None:
                        target_ids = [
                            targets[c][index] for c in range(opt.class_num)
                        ]
                        rids = [id2rid[j][k] for j, k in enumerate(target_ids)]
                        class_label += "_"
                        class_label += "#".join(
                            [rid2name[j][k] for j, k in enumerate(rids)])
                    image_list.append((class_label, input_im))
                image_dict = OrderedDict(image_list)
                # True only on multiples of update_html_freq. The previous bare
                # modulo was truthy on every *other* iteration, which inverted
                # the intent.
                save_result = total_batch_iter % opt.update_html_freq == 0
                webvis.plot_images(image_dict,
                                   opt.display_id + 2 * opt.class_num, epoch,
                                   save_result)

            # Validate and log (and optionally plot) validation metrics.
            if (len(val_set) > 0
                    and total_batch_iter % opt.display_validate_freq == 0):
                val_accuracy, val_loss = validate(model, criterion, val_set,
                                                  opt)
                x_axis = epoch + float(epoch_batch_iter) / train_batch_num
                accuracy_list = [
                    val_accuracy[c][opt.top_k[0]]["ratio"]
                    for c in range(len(val_accuracy))
                ]
                util.print_loss(val_loss, "Validate", epoch, total_batch_iter)
                util.print_accuracy(val_accuracy, "Validate", epoch,
                                    total_batch_iter)
                if opt.display_id > 0:
                    webvis.plot_points(x_axis, val_loss, "Loss", "Validate")
                    webvis.plot_points(x_axis, accuracy_list, "Accuracy",
                                       "Validate")

            # Periodic snapshot, labelled with the current 0-based epoch.
            if total_batch_iter % opt.save_batch_iter_freq == 0:
                logging.info(
                    "saving the latest model (epoch %d, total_batch_iter %d)",
                    epoch, total_batch_iter)
                save_model(model, opt, epoch)
                # TODO snapshot loss and accuracy

        logging.info('End of epoch %d / %d \t Time Taken: %d sec',
                     epoch, opt.sum_epoch, time.time() - epoch_start_t)

        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d',
                         epoch + 1, total_batch_iter)
            save_model(model, opt, epoch + 1)

        # Advance the LR schedule once per epoch, after the optimizer steps.
        scheduler.step()
        lr = optimizer.param_groups[0]['lr']
        logging.info('learning rate = %.7f epoch = %d', lr, epoch)
    logging.info("--------Optimization Done--------")