Example #1
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch,
                                         (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            with open(args.logdir + "/" + args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            lr /= 4.0
            if args.opt_method == "YF":
                optimizer.set_lr_factor(optimizer.get_lr_factor() / 4.0)
            else:
                for group in optimizer.param_groups:
                    group['lr'] /= 4.0
        if args.opt_method == "YF":
            # Record YellowFin's internally tuned momentum and learning rate for logging.
            mu_list.append(optimizer._mu)
            lr_list.append(optimizer._lr)
        # Dump the training curves collected so far to the log directory.
        with open(args.logdir + "/loss.txt", "wb") as f:
            np.savetxt(f, np.array(train_loss_list))
        with open(args.logdir + "/val_loss.txt", "wb") as f:
            np.savetxt(f, np.array(val_loss_list))
        with open(args.logdir + "/lr.txt", "wb") as f:
            np.savetxt(f, np.array(lr_list))
        with open(args.logdir + "/mu.txt", "wb") as f:
            np.savetxt(f, np.array(mu_list))
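
The loop above checkpoints the whole module with torch.save(model, f) and writes the loss, learning-rate, and momentum curves to plain-text files. A minimal sketch of reading those artifacts back afterwards, assuming a log directory "logs" and a checkpoint name "model.pt" as hypothetical stand-ins for args.logdir and args.save:

import numpy as np
import torch

logdir = "logs"                    # hypothetical stand-in for args.logdir
ckpt_path = logdir + "/model.pt"   # hypothetical stand-in for args.save

# torch.save(model, f) stored the full module, so torch.load returns it directly.
with open(ckpt_path, 'rb') as f:
    model = torch.load(f)

# The curves were written with np.savetxt, so np.loadtxt recovers them as 1-D arrays.
train_loss = np.loadtxt(logdir + "/loss.txt")
val_loss = np.loadtxt(logdir + "/val_loss.txt")
lr_curve = np.loadtxt(logdir + "/lr.txt")
mu_curve = np.loadtxt(logdir + "/mu.txt")
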
    val_loss = evaluate()
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
            'valid ppl {:8.2f} |'.format(epoch, (time.time() - epoch_start_time),
                                       val_loss, math.exp(val_loss)))
    print('-' * 89)

    # Save the model if the validation loss is the best we've seen so far.
    if not best_val_loss or val_loss < best_val_loss:
        with open(args.save, 'wb') as f:
            torch.save(model, f)
        best_val_loss = val_loss
        # Decay the learning rate slightly even when validation improves.
        lr_decay = 1.15
        lr /= lr_decay
        if args.opt_method == "YF":
            optimizer.set_lr_factor(optimizer.get_lr_factor() / lr_decay)
        else:
            for group in optimizer.param_groups:
                group['lr'] /= lr_decay
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr_decay = 4.0
        lr /= lr_decay
        if args.opt_method == "YF":
            optimizer.set_lr_factor(optimizer.get_lr_factor() / lr_decay)
        else:
            for group in optimizer.param_groups:
                group['lr'] /= lr_decay
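
The same decay step appears in every branch above (and in the epoch loop earlier), so it can be factored into a small helper. A minimal sketch, assuming an optimizer that either exposes get_lr_factor()/set_lr_factor() (as the YellowFin optimizer selected by opt_method == "YF" does here) or standard param_groups; the name anneal_lr is hypothetical:

def anneal_lr(optimizer, opt_method, lr, decay):
    """Divide the effective learning rate by `decay` and return the new value."""
    if opt_method == "YF":
        # YellowFin tunes its own learning rate; scale it through the multiplicative factor.
        optimizer.set_lr_factor(optimizer.get_lr_factor() / decay)
    else:
        # Standard PyTorch optimizers expose the learning rate per parameter group.
        for group in optimizer.param_groups:
            group['lr'] /= decay
    return lr / decay

With such a helper, each branch above reduces to a single call, e.g. lr = anneal_lr(optimizer, args.opt_method, lr, lr_decay).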